From 91acc3ea75d219ad07ed2b106e7b8bdcb01516dd Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Fri, 4 Mar 2022 17:17:36 +0900 Subject: [PATCH 001/424] Fix entity linker batching (#9669) * Partial fix of entity linker batching * Add import * Better name * Add `use_gold_ents` option, docs * Change to v2, create stub v1, update docs etc. * Fix error type Honestly no idea what the right type to use here is. ConfigValidationError seems wrong. Maybe a NotImplementedError? * Make mypy happy * Add hacky fix for init issue * Add legacy pipeline entity linker * Fix references to class name * Add __init__.py for legacy * Attempted fix for loss issue * Remove placeholder V1 * formatting * slightly more interesting train data * Handle batches with no usable examples This adds a test for batches that have docs but not entities, and a check in the component that detects such cases and skips the update step as though the batch were empty. * Remove todo about data verification Check for empty data was moved further up so this should be OK now - the case in question shouldn't be possible. * Fix gradient calculation The model doesn't know which entities are not in the kb, so it generates embeddings for the context of all of them. However, the loss does know which entities aren't in the kb, and it ignores them, as there's no sensible gradient. This has the issue that the gradient will not be calculated for some of the input embeddings, which causes a dimension mismatch in backprop. That should have caused a clear error, but with numpyops it was causing NaNs to happen, which is another problem that should be addressed separately. This commit changes the loss to give a zero gradient for entities not in the kb. * add failing test for v1 EL legacy architecture * Add nasty but simple working check for legacy arch * Clarify why init hack works the way it does * Clarify use_gold_ents use case * Fix use gold ents related handling * Add tests for no gold ents and fix other tests * Use aligned ents function (not working) This doesn't actually work because the "aligned" ents are gold-only. But if I have a different function that returns the intersection, *then* this will work as desired. * Use proper matching ent check This changes the process when gold ents are not used so that the intersection of ents in the pred and gold is used.
* Move get_matching_ents to Example * Use model attribute to check for legacy arch * Rename flag * bump spacy-legacy lower bound to 3.0.9 Co-authored-by: svlandeg --- requirements.txt | 2 +- setup.cfg | 2 +- spacy/cli/templates/quickstart_training.jinja | 4 +- spacy/ml/extract_spans.py | 2 +- spacy/ml/models/entity_linker.py | 60 ++- spacy/pipeline/entity_linker.py | 135 ++++-- spacy/pipeline/legacy/__init__.py | 3 + spacy/pipeline/legacy/entity_linker.py | 427 ++++++++++++++++++ spacy/tests/pipeline/test_entity_linker.py | 151 ++++++- spacy/training/example.pyx | 23 + website/docs/api/architectures.md | 4 +- website/docs/api/entitylinker.md | 1 + 12 files changed, 765 insertions(+), 49 deletions(-) create mode 100644 spacy/pipeline/legacy/__init__.py create mode 100644 spacy/pipeline/legacy/entity_linker.py diff --git a/requirements.txt b/requirements.txt index ca4099be5..b8970f686 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ # Our libraries -spacy-legacy>=3.0.8,<3.1.0 +spacy-legacy>=3.0.9,<3.1.0 spacy-loggers>=1.0.0,<2.0.0 cymem>=2.0.2,<2.1.0 preshed>=3.0.2,<3.1.0 diff --git a/setup.cfg b/setup.cfg index 586a044ff..ed3bf63ce 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,7 +41,7 @@ setup_requires = thinc>=8.0.12,<8.1.0 install_requires = # Our libraries - spacy-legacy>=3.0.8,<3.1.0 + spacy-legacy>=3.0.9,<3.1.0 spacy-loggers>=1.0.0,<2.0.0 murmurhash>=0.28.0,<1.1.0 cymem>=2.0.2,<2.1.0 diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja index fb79a4f60..da533b767 100644 --- a/spacy/cli/templates/quickstart_training.jinja +++ b/spacy/cli/templates/quickstart_training.jinja @@ -131,7 +131,7 @@ incl_context = true incl_prior = true [components.entity_linker.model] -@architectures = "spacy.EntityLinker.v1" +@architectures = "spacy.EntityLinker.v2" nO = null [components.entity_linker.model.tok2vec] @@ -303,7 +303,7 @@ incl_context = true incl_prior = true [components.entity_linker.model] -@architectures = "spacy.EntityLinker.v1" +@architectures = "spacy.EntityLinker.v2" nO = null [components.entity_linker.model.tok2vec] diff --git a/spacy/ml/extract_spans.py b/spacy/ml/extract_spans.py index edc86ff9c..d5e9bc07c 100644 --- a/spacy/ml/extract_spans.py +++ b/spacy/ml/extract_spans.py @@ -63,4 +63,4 @@ def _get_span_indices(ops, spans: Ragged, lengths: Ints1d) -> Ints1d: def _ensure_cpu(spans: Ragged, lengths: Ints1d) -> Tuple[Ragged, Ints1d]: - return (Ragged(to_numpy(spans.dataXd), to_numpy(spans.lengths)), to_numpy(lengths)) + return Ragged(to_numpy(spans.dataXd), to_numpy(spans.lengths)), to_numpy(lengths) diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py index 831fee90f..0149bea89 100644 --- a/spacy/ml/models/entity_linker.py +++ b/spacy/ml/models/entity_linker.py @@ -1,34 +1,82 @@ from pathlib import Path -from typing import Optional, Callable, Iterable, List +from typing import Optional, Callable, Iterable, List, Tuple from thinc.types import Floats2d from thinc.api import chain, clone, list2ragged, reduce_mean, residual -from thinc.api import Model, Maxout, Linear +from thinc.api import Model, Maxout, Linear, noop, tuplify, Ragged from ...util import registry from ...kb import KnowledgeBase, Candidate, get_candidates from ...vocab import Vocab from ...tokens import Span, Doc +from ..extract_spans import extract_spans +from ...errors import Errors -@registry.architectures("spacy.EntityLinker.v1") +@registry.architectures("spacy.EntityLinker.v2") def build_nel_encoder( tok2vec: Model, nO: 
Optional[int] = None ) -> Model[List[Doc], Floats2d]: - with Model.define_operators({">>": chain, "**": clone}): + with Model.define_operators({">>": chain, "&": tuplify}): token_width = tok2vec.maybe_get_dim("nO") output_layer = Linear(nO=nO, nI=token_width) model = ( - tok2vec - >> list2ragged() + ((tok2vec >> list2ragged()) & build_span_maker()) + >> extract_spans() >> reduce_mean() >> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0)) # type: ignore[arg-type] >> output_layer ) model.set_ref("output_layer", output_layer) model.set_ref("tok2vec", tok2vec) + # flag to show this isn't legacy + model.attrs["include_span_maker"] = True return model +def build_span_maker(n_sents: int = 0) -> Model: + model: Model = Model("span_maker", forward=span_maker_forward) + model.attrs["n_sents"] = n_sents + return model + + +def span_maker_forward(model, docs: List[Doc], is_train) -> Tuple[Ragged, Callable]: + ops = model.ops + n_sents = model.attrs["n_sents"] + candidates = [] + for doc in docs: + cands = [] + try: + sentences = [s for s in doc.sents] + except ValueError: + # no sentence info, normal in initialization + for tok in doc: + tok.is_sent_start = tok.i == 0 + sentences = [doc[:]] + for ent in doc.ents: + try: + # find the sentence in the list of sentences. + sent_index = sentences.index(ent.sent) + except AttributeError: + # Catch the exception when ent.sent is None and provide a user-friendly warning + raise RuntimeError(Errors.E030) from None + # get n previous sentences, if there are any + start_sentence = max(0, sent_index - n_sents) + # get n posterior sentences, or as many < n as there are + end_sentence = min(len(sentences) - 1, sent_index + n_sents) + # get token positions + start_token = sentences[start_sentence].start + end_token = sentences[end_sentence].end + # save positions for extraction + cands.append((start_token, end_token)) + + candidates.append(ops.asarray2i(cands)) + candlens = ops.asarray1i([len(cands) for cands in candidates]) + candidates = ops.xp.concatenate(candidates) + outputs = Ragged(candidates, candlens) + # because this is just rearranging docs, the backprop does nothing + return outputs, lambda x: [] + + @registry.misc("spacy.KBFromFile.v1") def load_kb(kb_path: Path) -> Callable[[Vocab], KnowledgeBase]: def kb_from_file(vocab): diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 1169e898d..89e7576bf 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -6,17 +6,17 @@ import srsly import random from thinc.api import CosineDistance, Model, Optimizer, Config from thinc.api import set_dropout_rate -import warnings from ..kb import KnowledgeBase, Candidate from ..ml import empty_kb from ..tokens import Doc, Span from .pipe import deserialize_config +from .legacy.entity_linker import EntityLinker_v1 from .trainable_pipe import TrainablePipe from ..language import Language from ..vocab import Vocab from ..training import Example, validate_examples, validate_get_examples -from ..errors import Errors, Warnings +from ..errors import Errors from ..util import SimpleFrozenList, registry from .. 
import util from ..scorer import Scorer @@ -26,7 +26,7 @@ BACKWARD_OVERWRITE = True default_model_config = """ [model] -@architectures = "spacy.EntityLinker.v1" +@architectures = "spacy.EntityLinker.v2" [model.tok2vec] @architectures = "spacy.HashEmbedCNN.v2" @@ -55,6 +55,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"] "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, "overwrite": True, "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"}, + "use_gold_ents": True, }, default_score_weights={ "nel_micro_f": 1.0, @@ -75,6 +76,7 @@ def make_entity_linker( get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], overwrite: bool, scorer: Optional[Callable], + use_gold_ents: bool, ): """Construct an EntityLinker component. @@ -90,6 +92,22 @@ produces a list of candidates, given a certain knowledge base and a textual mention. scorer (Optional[Callable]): The scoring method. """ + + if not model.attrs.get("include_span_maker", False): + # The only difference in arguments here is that use_gold_ents is not available + return EntityLinker_v1( + nlp.vocab, + model, + name, + labels_discard=labels_discard, + n_sents=n_sents, + incl_prior=incl_prior, + incl_context=incl_context, + entity_vector_length=entity_vector_length, + get_candidates=get_candidates, + overwrite=overwrite, + scorer=scorer, + ) return EntityLinker( nlp.vocab, model, @@ -102,6 +120,7 @@ get_candidates=get_candidates, overwrite=overwrite, scorer=scorer, + use_gold_ents=use_gold_ents, ) @@ -136,6 +155,7 @@ class EntityLinker(TrainablePipe): get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], overwrite: bool = BACKWARD_OVERWRITE, scorer: Optional[Callable] = entity_linker_score, + use_gold_ents: bool, ) -> None: """Initialize an entity linker. @@ -152,6 +172,8 @@ produces a list of candidates, given a certain knowledge base and a textual mention. scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links. + use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another + component must provide entity annotations. DOCS: https://spacy.io/api/entitylinker#init """ @@ -169,6 +191,7 @@ # create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'. self.kb = empty_kb(entity_vector_length)(self.vocab) self.scorer = scorer + self.use_gold_ents = use_gold_ents def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]): """Define the KB of this pipe by providing a function that will @@ -212,14 +235,48 @@ doc_sample = [] vector_sample = [] for example in islice(get_examples(), 10): - doc_sample.append(example.x) + doc = example.x + if self.use_gold_ents: + doc.ents = example.y.ents + doc_sample.append(doc) vector_sample.append(self.model.ops.alloc1f(nO)) assert len(doc_sample) > 0, Errors.E923.format(name=self.name) assert len(vector_sample) > 0, Errors.E923.format(name=self.name) + + # XXX In order for size estimation to work, there has to be at least + # one entity. It's not used for training so it doesn't have to be real, + # so we add a fake one if none are present. + # We can't use Doc.has_annotation here because it can be True for docs + # that have been through an NER component but got no entities. 
+ has_annotations = any([doc.ents for doc in doc_sample]) + if not has_annotations: + doc = doc_sample[0] + ent = doc[0:1] + ent.label_ = "XXX" + doc.ents = (ent,) + self.model.initialize( X=doc_sample, Y=self.model.ops.asarray(vector_sample, dtype="float32") ) + if not has_annotations: + # Clean up dummy annotation + doc.ents = [] + + def batch_has_learnable_example(self, examples): + """Check if a batch contains a learnable example. + + If one isn't present, then the update step needs to be skipped. + """ + + for eg in examples: + for ent in eg.predicted.ents: + candidates = list(self.get_candidates(self.kb, ent)) + if candidates: + return True + + return False + def update( self, examples: Iterable[Example], @@ -247,35 +304,29 @@ if not examples: return losses validate_examples(examples, "EntityLinker.update") - sentence_docs = [] - for eg in examples: - sentences = [s for s in eg.reference.sents] - kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True) - for ent in eg.reference.ents: - # KB ID of the first token is the same as the whole span - kb_id = kb_ids[ent.start] - if kb_id: - try: - # find the sentence in the list of sentences. - sent_index = sentences.index(ent.sent) - except AttributeError: - # Catch the exception when ent.sent is None and provide a user-friendly warning - raise RuntimeError(Errors.E030) from None - # get n previous sentences, if there are any - start_sentence = max(0, sent_index - self.n_sents) - # get n posterior sentences, or as many < n as there are - end_sentence = min(len(sentences) - 1, sent_index + self.n_sents) - # get token positions - start_token = sentences[start_sentence].start - end_token = sentences[end_sentence].end - # append that span as a doc to training - sent_doc = eg.predicted[start_token:end_token].as_doc() - sentence_docs.append(sent_doc) + set_dropout_rate(self.model, drop) - if not sentence_docs: - warnings.warn(Warnings.W093.format(name="Entity Linker")) + docs = [eg.predicted for eg in examples] + # save to restore later + old_ents = [doc.ents for doc in docs] + + for doc, ex in zip(docs, examples): + if self.use_gold_ents: + doc.ents = ex.reference.ents + else: + # only keep matching ents + doc.ents = ex.get_matching_ents() + + # make sure we have something to learn from, if not, short-circuit + if not self.batch_has_learnable_example(examples): return losses + + sentence_encodings, bp_context = self.model.begin_update(docs) + + # now restore the ents + for doc, old in zip(docs, old_ents): + doc.ents = old + loss, d_scores = self.get_loss( sentence_encodings=sentence_encodings, examples=examples ) @@ -288,24 +339,38 @@ def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d): validate_examples(examples, "EntityLinker.get_loss") entity_encodings = [] + eidx = 0 # indices in gold entities to keep + keep_ents = [] # indices in sentence_encodings to keep + for eg in examples: kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True) + for ent in eg.reference.ents: kb_id = kb_ids[ent.start] if kb_id: entity_encoding = self.kb.get_vector(kb_id) entity_encodings.append(entity_encoding) + keep_ents.append(eidx) + + eidx += 1 entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32") - if sentence_encodings.shape != entity_encodings.shape: + selected_encodings = sentence_encodings[keep_ents] + + # If the entity encodings list is empty, then the shapes will not match and we raise an error below + if 
selected_encodings.shape != entity_encodings.shape: err = Errors.E147.format( method="get_loss", msg="gold entities do not match up" ) raise RuntimeError(err) # TODO: fix typing issue here - gradients = self.distance.get_grad(sentence_encodings, entity_encodings) # type: ignore - loss = self.distance.get_loss(sentence_encodings, entity_encodings) # type: ignore + gradients = self.distance.get_grad(selected_encodings, entity_encodings) # type: ignore + # to match the input size, we need to give a zero gradient for items not in the kb + out = self.model.ops.alloc2f(*sentence_encodings.shape) + out[keep_ents] = gradients + + loss = self.distance.get_loss(selected_encodings, entity_encodings) # type: ignore loss = loss / len(entity_encodings) - return float(loss), gradients + return float(loss), out def predict(self, docs: Iterable[Doc]) -> List[str]: """Apply the pipeline's model to a batch of docs, without modifying them. diff --git a/spacy/pipeline/legacy/__init__.py b/spacy/pipeline/legacy/__init__.py new file mode 100644 index 000000000..f216840dc --- /dev/null +++ b/spacy/pipeline/legacy/__init__.py @@ -0,0 +1,3 @@ +from .entity_linker import EntityLinker_v1 + +__all__ = ["EntityLinker_v1"] diff --git a/spacy/pipeline/legacy/entity_linker.py b/spacy/pipeline/legacy/entity_linker.py new file mode 100644 index 000000000..6440c18e5 --- /dev/null +++ b/spacy/pipeline/legacy/entity_linker.py @@ -0,0 +1,427 @@ +# This file is present to provide a prior version of the EntityLinker component +# for backwards compatibility. For details see #9669. + +from typing import Optional, Iterable, Callable, Dict, Union, List, Any +from thinc.types import Floats2d +from pathlib import Path +from itertools import islice +import srsly +import random +from thinc.api import CosineDistance, Model, Optimizer, Config +from thinc.api import set_dropout_rate +import warnings + +from ...kb import KnowledgeBase, Candidate +from ...ml import empty_kb +from ...tokens import Doc, Span +from ..pipe import deserialize_config +from ..trainable_pipe import TrainablePipe +from ...language import Language +from ...vocab import Vocab +from ...training import Example, validate_examples, validate_get_examples +from ...errors import Errors, Warnings +from ...util import SimpleFrozenList, registry +from ... import util +from ...scorer import Scorer + +# See #9050 +BACKWARD_OVERWRITE = True + + +def entity_linker_score(examples, **kwargs): + return Scorer.score_links(examples, negative_labels=[EntityLinker_v1.NIL], **kwargs) + + +class EntityLinker_v1(TrainablePipe): + """Pipeline component for named entity linking. + + DOCS: https://spacy.io/api/entitylinker + """ + + NIL = "NIL" # string used to refer to a non-existing link + + def __init__( + self, + vocab: Vocab, + model: Model, + name: str = "entity_linker", + *, + labels_discard: Iterable[str], + n_sents: int, + incl_prior: bool, + incl_context: bool, + entity_vector_length: int, + get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], + overwrite: bool = BACKWARD_OVERWRITE, + scorer: Optional[Callable] = entity_linker_score, + ) -> None: + """Initialize an entity linker. + + vocab (Vocab): The shared vocabulary. + model (thinc.api.Model): The Thinc Model powering the pipeline component. + name (str): The component instance name, used to add entries to the + losses during training. + labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction. + n_sents (int): The number of neighbouring sentences to take into account. 
+ incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. + incl_context (bool): Whether or not to include the local context in the model. + entity_vector_length (int): Size of encoding vectors in the KB. + get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that + produces a list of candidates, given a certain knowledge base and a textual mention. + scorer (Optional[Callable]): The scoring method. Defaults to + Scorer.score_links. + + DOCS: https://spacy.io/api/entitylinker#init + """ + self.vocab = vocab + self.model = model + self.name = name + self.labels_discard = list(labels_discard) + self.n_sents = n_sents + self.incl_prior = incl_prior + self.incl_context = incl_context + self.get_candidates = get_candidates + self.cfg: Dict[str, Any] = {"overwrite": overwrite} + self.distance = CosineDistance(normalize=False) + # how many neighbour sentences to take into account + # create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'. + self.kb = empty_kb(entity_vector_length)(self.vocab) + self.scorer = scorer + + def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]): + """Define the KB of this pipe by providing a function that will + create it using this object's vocab.""" + if not callable(kb_loader): + raise ValueError(Errors.E885.format(arg_type=type(kb_loader))) + + self.kb = kb_loader(self.vocab) + + def validate_kb(self) -> None: + # Raise an error if the knowledge base is not initialized. + if self.kb is None: + raise ValueError(Errors.E1018.format(name=self.name)) + if len(self.kb) == 0: + raise ValueError(Errors.E139.format(name=self.name)) + + def initialize( + self, + get_examples: Callable[[], Iterable[Example]], + *, + nlp: Optional[Language] = None, + kb_loader: Optional[Callable[[Vocab], KnowledgeBase]] = None, + ): + """Initialize the pipe for training, using a representative set + of data examples. + + get_examples (Callable[[], Iterable[Example]]): Function that + returns a representative sample of gold-standard Example objects. + nlp (Language): The current nlp object the component is part of. + kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance. + Note that providing this argument, will overwrite all data accumulated in the current KB. + Use this only when loading a KB as-such from file. + + DOCS: https://spacy.io/api/entitylinker#initialize + """ + validate_get_examples(get_examples, "EntityLinker_v1.initialize") + if kb_loader is not None: + self.set_kb(kb_loader) + self.validate_kb() + nO = self.kb.entity_vector_length + doc_sample = [] + vector_sample = [] + for example in islice(get_examples(), 10): + doc_sample.append(example.x) + vector_sample.append(self.model.ops.alloc1f(nO)) + assert len(doc_sample) > 0, Errors.E923.format(name=self.name) + assert len(vector_sample) > 0, Errors.E923.format(name=self.name) + self.model.initialize( + X=doc_sample, Y=self.model.ops.asarray(vector_sample, dtype="float32") + ) + + def update( + self, + examples: Iterable[Example], + *, + drop: float = 0.0, + sgd: Optional[Optimizer] = None, + losses: Optional[Dict[str, float]] = None, + ) -> Dict[str, float]: + """Learn from a batch of documents and gold-standard information, + updating the pipe's model. Delegates to predict and get_loss. + + examples (Iterable[Example]): A batch of Example objects. + drop (float): The dropout rate. + sgd (thinc.api.Optimizer): The optimizer. 
+ losses (Dict[str, float]): Optional record of the loss during training. + Updated using the component name as the key. + RETURNS (Dict[str, float]): The updated losses dictionary. + + DOCS: https://spacy.io/api/entitylinker#update + """ + self.validate_kb() + if losses is None: + losses = {} + losses.setdefault(self.name, 0.0) + if not examples: + return losses + validate_examples(examples, "EntityLinker_v1.update") + sentence_docs = [] + for eg in examples: + sentences = [s for s in eg.reference.sents] + kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True) + for ent in eg.reference.ents: + # KB ID of the first token is the same as the whole span + kb_id = kb_ids[ent.start] + if kb_id: + try: + # find the sentence in the list of sentences. + sent_index = sentences.index(ent.sent) + except AttributeError: + # Catch the exception when ent.sent is None and provide a user-friendly warning + raise RuntimeError(Errors.E030) from None + # get n previous sentences, if there are any + start_sentence = max(0, sent_index - self.n_sents) + # get n posterior sentences, or as many < n as there are + end_sentence = min(len(sentences) - 1, sent_index + self.n_sents) + # get token positions + start_token = sentences[start_sentence].start + end_token = sentences[end_sentence].end + # append that span as a doc to training + sent_doc = eg.predicted[start_token:end_token].as_doc() + sentence_docs.append(sent_doc) + set_dropout_rate(self.model, drop) + if not sentence_docs: + warnings.warn(Warnings.W093.format(name="Entity Linker")) + return losses + sentence_encodings, bp_context = self.model.begin_update(sentence_docs) + loss, d_scores = self.get_loss( + sentence_encodings=sentence_encodings, examples=examples + ) + bp_context(d_scores) + if sgd is not None: + self.finish_update(sgd) + losses[self.name] += loss + return losses + + def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d): + validate_examples(examples, "EntityLinker_v1.get_loss") + entity_encodings = [] + for eg in examples: + kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True) + for ent in eg.reference.ents: + kb_id = kb_ids[ent.start] + if kb_id: + entity_encoding = self.kb.get_vector(kb_id) + entity_encodings.append(entity_encoding) + entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32") + if sentence_encodings.shape != entity_encodings.shape: + err = Errors.E147.format( + method="get_loss", msg="gold entities do not match up" + ) + raise RuntimeError(err) + # TODO: fix typing issue here + gradients = self.distance.get_grad(sentence_encodings, entity_encodings) # type: ignore + loss = self.distance.get_loss(sentence_encodings, entity_encodings) # type: ignore + loss = loss / len(entity_encodings) + return float(loss), gradients + + def predict(self, docs: Iterable[Doc]) -> List[str]: + """Apply the pipeline's model to a batch of docs, without modifying them. + Returns the KB IDs for each entity in each doc, including NIL if there is + no prediction. + + docs (Iterable[Doc]): The documents to predict. + RETURNS (List[str]): The models prediction for each document. 
+ + DOCS: https://spacy.io/api/entitylinker#predict + """ + self.validate_kb() + entity_count = 0 + final_kb_ids: List[str] = [] + if not docs: + return final_kb_ids + if isinstance(docs, Doc): + docs = [docs] + for i, doc in enumerate(docs): + sentences = [s for s in doc.sents] + if len(doc) > 0: + # Looping through each entity (TODO: rewrite) + for ent in doc.ents: + sent = ent.sent + sent_index = sentences.index(sent) + assert sent_index >= 0 + # get n_neighbour sentences, clipped to the length of the document + start_sentence = max(0, sent_index - self.n_sents) + end_sentence = min(len(sentences) - 1, sent_index + self.n_sents) + start_token = sentences[start_sentence].start + end_token = sentences[end_sentence].end + sent_doc = doc[start_token:end_token].as_doc() + # currently, the context is the same for each entity in a sentence (should be refined) + xp = self.model.ops.xp + if self.incl_context: + sentence_encoding = self.model.predict([sent_doc])[0] + sentence_encoding_t = sentence_encoding.T + sentence_norm = xp.linalg.norm(sentence_encoding_t) + entity_count += 1 + if ent.label_ in self.labels_discard: + # ignoring this entity - setting to NIL + final_kb_ids.append(self.NIL) + else: + candidates = list(self.get_candidates(self.kb, ent)) + if not candidates: + # no prediction possible for this entity - setting to NIL + final_kb_ids.append(self.NIL) + elif len(candidates) == 1: + # shortcut for efficiency reasons: take the 1 candidate + # TODO: thresholding + final_kb_ids.append(candidates[0].entity_) + else: + random.shuffle(candidates) + # set all prior probabilities to 0 if incl_prior=False + prior_probs = xp.asarray([c.prior_prob for c in candidates]) + if not self.incl_prior: + prior_probs = xp.asarray([0.0 for _ in candidates]) + scores = prior_probs + # add in similarity from the context + if self.incl_context: + entity_encodings = xp.asarray( + [c.entity_vector for c in candidates] + ) + entity_norm = xp.linalg.norm(entity_encodings, axis=1) + if len(entity_encodings) != len(prior_probs): + raise RuntimeError( + Errors.E147.format( + method="predict", + msg="vectors not of equal length", + ) + ) + # cosine similarity + sims = xp.dot(entity_encodings, sentence_encoding_t) / ( + sentence_norm * entity_norm + ) + if sims.shape != prior_probs.shape: + raise ValueError(Errors.E161) + scores = prior_probs + sims - (prior_probs * sims) + # TODO: thresholding + best_index = scores.argmax().item() + best_candidate = candidates[best_index] + final_kb_ids.append(best_candidate.entity_) + if not (len(final_kb_ids) == entity_count): + err = Errors.E147.format( + method="predict", msg="result variables not of equal length" + ) + raise RuntimeError(err) + return final_kb_ids + + def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None: + """Modify a batch of documents, using pre-computed scores. + + docs (Iterable[Doc]): The documents to modify. + kb_ids (List[str]): The IDs to set, produced by EntityLinker.predict. + + DOCS: https://spacy.io/api/entitylinker#set_annotations + """ + count_ents = len([ent for doc in docs for ent in doc.ents]) + if count_ents != len(kb_ids): + raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids))) + i = 0 + overwrite = self.cfg["overwrite"] + for doc in docs: + for ent in doc.ents: + kb_id = kb_ids[i] + i += 1 + for token in ent: + if token.ent_kb_id == 0 or overwrite: + token.ent_kb_id_ = kb_id + + def to_bytes(self, *, exclude=tuple()): + """Serialize the pipe to a bytestring. 
+ + exclude (Iterable[str]): String names of serialization fields to exclude. + RETURNS (bytes): The serialized object. + + DOCS: https://spacy.io/api/entitylinker#to_bytes + """ + self._validate_serialization_attrs() + serialize = {} + if hasattr(self, "cfg") and self.cfg is not None: + serialize["cfg"] = lambda: srsly.json_dumps(self.cfg) + serialize["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude) + serialize["kb"] = self.kb.to_bytes + serialize["model"] = self.model.to_bytes + return util.to_bytes(serialize, exclude) + + def from_bytes(self, bytes_data, *, exclude=tuple()): + """Load the pipe from a bytestring. + + exclude (Iterable[str]): String names of serialization fields to exclude. + RETURNS (TrainablePipe): The loaded object. + + DOCS: https://spacy.io/api/entitylinker#from_bytes + """ + self._validate_serialization_attrs() + + def load_model(b): + try: + self.model.from_bytes(b) + except AttributeError: + raise ValueError(Errors.E149) from None + + deserialize = {} + if hasattr(self, "cfg") and self.cfg is not None: + deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b)) + deserialize["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude) + deserialize["kb"] = lambda b: self.kb.from_bytes(b) + deserialize["model"] = load_model + util.from_bytes(bytes_data, deserialize, exclude) + return self + + def to_disk( + self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList() + ) -> None: + """Serialize the pipe to disk. + + path (str / Path): Path to a directory. + exclude (Iterable[str]): String names of serialization fields to exclude. + + DOCS: https://spacy.io/api/entitylinker#to_disk + """ + serialize = {} + serialize["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude) + serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg) + serialize["kb"] = lambda p: self.kb.to_disk(p) + serialize["model"] = lambda p: self.model.to_disk(p) + util.to_disk(path, serialize, exclude) + + def from_disk( + self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList() + ) -> "EntityLinker_v1": + """Load the pipe from disk. Modifies the object in place and returns it. + + path (str / Path): Path to a directory. + exclude (Iterable[str]): String names of serialization fields to exclude. + RETURNS (EntityLinker): The modified EntityLinker object. 
+ + DOCS: https://spacy.io/api/entitylinker#from_disk + """ + + def load_model(p): + try: + with p.open("rb") as infile: + self.model.from_bytes(infile.read()) + except AttributeError: + raise ValueError(Errors.E149) from None + + deserialize: Dict[str, Callable[[Any], Any]] = {} + deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p)) + deserialize["vocab"] = lambda p: self.vocab.from_disk(p, exclude=exclude) + deserialize["kb"] = lambda p: self.kb.from_disk(p) + deserialize["model"] = load_model + util.from_disk(path, deserialize, exclude) + return self + + def rehearse(self, examples, *, sgd=None, losses=None, **config): + raise NotImplementedError + + def add_label(self, label): + raise NotImplementedError diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 3740e430e..7d1382741 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -9,6 +9,9 @@ from spacy.compat import pickle from spacy.kb import Candidate, KnowledgeBase, get_candidates from spacy.lang.en import English from spacy.ml import load_kb +from spacy.pipeline import EntityLinker +from spacy.pipeline.legacy import EntityLinker_v1 +from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL from spacy.scorer import Scorer from spacy.tests.util import make_tempdir from spacy.tokens import Span @@ -168,6 +171,45 @@ def test_issue7065_b(): assert doc +def test_no_entities(): + # Test that having no entities doesn't crash the model + TRAIN_DATA = [ + ( + "The sky is blue.", + { + "sent_starts": [1, 0, 0, 0, 0], + }, + ) + ] + nlp = English() + vector_length = 3 + train_examples = [] + for text, annotation in TRAIN_DATA: + doc = nlp(text) + train_examples.append(Example.from_dict(doc, annotation)) + + def create_kb(vocab): + # create artificial KB + mykb = KnowledgeBase(vocab, entity_vector_length=vector_length) + mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3]) + mykb.add_alias("Russ Cochran", ["Q2146908"], [0.9]) + return mykb + + # Create and train the Entity Linker + entity_linker = nlp.add_pipe("entity_linker", last=True) + entity_linker.set_kb(create_kb) + optimizer = nlp.initialize(get_examples=lambda: train_examples) + for i in range(2): + losses = {} + nlp.update(train_examples, sgd=optimizer, losses=losses) + + # adding additional components that are required for the entity_linker + nlp.add_pipe("sentencizer", first=True) + + # this will run the pipeline on the examples and shouldn't crash + results = nlp.evaluate(train_examples) + + def test_partial_links(): # Test that having some entities on the doc without gold links, doesn't crash TRAIN_DATA = [ @@ -650,7 +692,7 @@ TRAIN_DATA = [ "sent_starts": [1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}), ("Russ Cochran his reprints include EC Comics.", {"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}}, - "entities": [(0, 12, "PERSON")], + "entities": [(0, 12, "PERSON"), (34, 43, "ART")], "sent_starts": [1, -1, 0, 0, 0, 0, 0, 0]}), ("Russ Cochran has been publishing comic art.", {"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}}, @@ -693,6 +735,7 @@ def test_overfitting_IO(): # Create the Entity Linker component and add it to the pipeline entity_linker = nlp.add_pipe("entity_linker", last=True) + assert isinstance(entity_linker, EntityLinker) entity_linker.set_kb(create_kb) assert "Q2146908" in entity_linker.vocab.strings assert "Q2146908" in entity_linker.kb.vocab.strings @@ -922,3 +965,109 @@ def test_scorer_links(): assert 
scores["nel_micro_p"] == 2 / 3 assert scores["nel_micro_r"] == 2 / 4 + + +# fmt: off +@pytest.mark.parametrize( + "name,config", + [ + ("entity_linker", {"@architectures": "spacy.EntityLinker.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL}), + ("entity_linker", {"@architectures": "spacy.EntityLinker.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL}), + ], +) +# fmt: on +def test_legacy_architectures(name, config): + # Ensure that the legacy architectures still work + vector_length = 3 + nlp = English() + + train_examples = [] + for text, annotation in TRAIN_DATA: + doc = nlp.make_doc(text) + train_examples.append(Example.from_dict(doc, annotation)) + + def create_kb(vocab): + mykb = KnowledgeBase(vocab, entity_vector_length=vector_length) + mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3]) + mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7]) + mykb.add_alias( + alias="Russ Cochran", + entities=["Q2146908", "Q7381115"], + probabilities=[0.5, 0.5], + ) + return mykb + + entity_linker = nlp.add_pipe(name, config={"model": config}) + if config["@architectures"] == "spacy.EntityLinker.v1": + assert isinstance(entity_linker, EntityLinker_v1) + else: + assert isinstance(entity_linker, EntityLinker) + entity_linker.set_kb(create_kb) + optimizer = nlp.initialize(get_examples=lambda: train_examples) + + for i in range(2): + losses = {} + nlp.update(train_examples, sgd=optimizer, losses=losses) + + @pytest.mark.parametrize("patterns", [ + # perfect case + [{"label": "CHARACTER", "pattern": "Kirby"}], + # typo for false negative + [{"label": "PERSON", "pattern": "Korby"}], + # random stuff for false positive + [{"label": "IS", "pattern": "is"}, {"label": "COLOR", "pattern": "pink"}], + ] + ) + def test_no_gold_ents(patterns): + # test that annotating components work + TRAIN_DATA = [ + ( + "Kirby is pink", + { + "links": {(0, 5): {"Q613241": 1.0}}, + "entities": [(0, 5, "CHARACTER")], + "sent_starts": [1, 0, 0], + }, + ) + ] + nlp = English() + vector_length = 3 + train_examples = [] + for text, annotation in TRAIN_DATA: + doc = nlp(text) + train_examples.append(Example.from_dict(doc, annotation)) + + # Create a ruler to mark entities + ruler = nlp.add_pipe("entity_ruler") + ruler.add_patterns(patterns) + + # Apply ruler to examples. In a real pipeline this would be an annotating component. 
+ for eg in train_examples: + eg.predicted = ruler(eg.predicted) + + def create_kb(vocab): + # create artificial KB + mykb = KnowledgeBase(vocab, entity_vector_length=vector_length) + mykb.add_entity(entity="Q613241", freq=12, entity_vector=[6, -4, 3]) + mykb.add_alias("Kirby", ["Q613241"], [0.9]) + # Placeholder + mykb.add_entity(entity="pink", freq=12, entity_vector=[7, 2, -5]) + mykb.add_alias("pink", ["pink"], [0.9]) + return mykb + + + # Create and train the Entity Linker + entity_linker = nlp.add_pipe("entity_linker", config={"use_gold_ents": False}, last=True) + entity_linker.set_kb(create_kb) + assert entity_linker.use_gold_ents == False + + optimizer = nlp.initialize(get_examples=lambda: train_examples) + for i in range(2): + losses = {} + nlp.update(train_examples, sgd=optimizer, losses=losses) + + # adding additional components that are required for the entity_linker + nlp.add_pipe("sentencizer", first=True) + + # this will run the pipeline on the examples and shouldn't crash + results = nlp.evaluate(train_examples) diff --git a/spacy/training/example.pyx b/spacy/training/example.pyx index d792c9bbf..a2c5e08e9 100644 --- a/spacy/training/example.pyx +++ b/spacy/training/example.pyx @@ -256,6 +256,29 @@ cdef class Example: x_ents, x_tags = self.get_aligned_ents_and_ner() return x_tags + def get_matching_ents(self, check_label=True): + """Return entities that are shared between predicted and reference docs. + + If `check_label` is True, entities must have matching labels to be + kept. Otherwise only the character indices need to match. + """ + gold = {} + for ent in self.reference: + gold[(ent.start_char, ent.end_char)] = ent.label + + keep = [] + for ent in self.predicted: + key = (ent.start_char, ent.end_char) + if key not in gold: + continue + + if check_label and ent.label != gold[key]: + continue + + keep.append(ent) + + return keep + def to_dict(self): return { "doc_annotation": { diff --git a/website/docs/api/architectures.md b/website/docs/api/architectures.md index 07b76393f..5fb3546a7 100644 --- a/website/docs/api/architectures.md +++ b/website/docs/api/architectures.md @@ -858,13 +858,13 @@ into the "real world". This requires 3 main components: - A machine learning [`Model`](https://thinc.ai/docs/api-model) that picks the most plausible ID from the set of candidates. -### spacy.EntityLinker.v1 {#EntityLinker} +### spacy.EntityLinker.v2 {#EntityLinker} > #### Example Config > > ```ini > [model] -> @architectures = "spacy.EntityLinker.v1" +> @architectures = "spacy.EntityLinker.v2" > nO = null > > [model.tok2vec] diff --git a/website/docs/api/entitylinker.md b/website/docs/api/entitylinker.md index 3d3372679..8e0d6087a 100644 --- a/website/docs/api/entitylinker.md +++ b/website/docs/api/entitylinker.md @@ -59,6 +59,7 @@ architectures and their arguments and hyperparameters. | `incl_context` | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~ | | `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ | | `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ | +| `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~bool~~ | | `get_candidates` | Function that generates plausible candidates for a given `Span` object. 
Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ | | `overwrite` 3.2 | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~ | | `scorer` 3.2 | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~ | From d89dac4066b3a245adb3982709bb7bb6eb9b9d63 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Fri, 4 Mar 2022 11:07:45 +0100 Subject: [PATCH 002/424] hook up meta in load_model_from_config (#10400) --- spacy/util.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/spacy/util.py b/spacy/util.py index 2a8b9f5cc..66e257dd8 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -485,13 +485,16 @@ def load_model_from_path( config_path = model_path / "config.cfg" overrides = dict_to_dot(config) config = load_config(config_path, overrides=overrides) - nlp = load_model_from_config(config, vocab=vocab, disable=disable, exclude=exclude) + nlp = load_model_from_config( + config, vocab=vocab, disable=disable, exclude=exclude, meta=meta + ) return nlp.from_disk(model_path, exclude=exclude, overrides=overrides) def load_model_from_config( config: Union[Dict[str, Any], Config], *, + meta: Dict[str, Any] = SimpleFrozenDict(), vocab: Union["Vocab", bool] = True, disable: Iterable[str] = SimpleFrozenList(), exclude: Iterable[str] = SimpleFrozenList(), @@ -529,6 +532,7 @@ def load_model_from_config( exclude=exclude, auto_fill=auto_fill, validate=validate, + meta=meta, ) return nlp From 6f4f57f3172112eb34336b0d6c0f0a0c930a5d1c Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Mon, 7 Mar 2022 18:41:03 +0900 Subject: [PATCH 003/424] Update Issue Templates (#10446) * Remove mention of python 3.10 wheels These were released a while ago, just forgot to remove this notice. * Add note about Discussions --- .github/ISSUE_TEMPLATE/01_bugs.md | 2 ++ .github/ISSUE_TEMPLATE/config.yml | 3 --- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/01_bugs.md b/.github/ISSUE_TEMPLATE/01_bugs.md index 768832c24..255a5241e 100644 --- a/.github/ISSUE_TEMPLATE/01_bugs.md +++ b/.github/ISSUE_TEMPLATE/01_bugs.md @@ -4,6 +4,8 @@ about: Use this template if you came across a bug or unexpected behaviour differ --- + + ## How to reproduce the behaviour diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index fce1a1064..31f89f917 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,8 +1,5 @@ blank_issues_enabled: false contact_links: - - name: ⚠️ Python 3.10 Support - url: https://github.com/explosion/spaCy/discussions/9418 - about: Python 3.10 wheels haven't been released yet, see the link for details. - name: 🗯 Discussions Forum url: https://github.com/explosion/spaCy/discussions about: Install issues, usage questions, general discussion and anything else that isn't a bug report. 
From a6d5824e5f8361078f4075541e7fd41b304cf379 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Mon, 7 Mar 2022 12:47:26 +0100 Subject: [PATCH 004/424] added classy-classification package to spacy universe (#10393) * Update universe.json added classy-classification to Spacy universe * Update universe.json added classy-classification to the spacy universe resources * Update universe.json corrected a small typo in json * Update website/meta/universe.json Co-authored-by: Sofie Van Landeghem * Update website/meta/universe.json Co-authored-by: Sofie Van Landeghem * Update website/meta/universe.json Co-authored-by: Sofie Van Landeghem * Update universe.json processed merge feedback * Update universe.json Co-authored-by: Sofie Van Landeghem --- website/meta/universe.json | 47 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/website/meta/universe.json b/website/meta/universe.json index 6374600f2..0179830d0 100644 --- a/website/meta/universe.json +++ b/website/meta/universe.json @@ -2599,6 +2599,53 @@ }, "category": ["pipeline"] }, + { + "id": "classyclassification", + "slogan": "A Python library for classy few-shot and zero-shot classification within spaCy.", + "description": "Huggingface does offer some nice models for few/zero-shot classification, but these are not tailored to multi-lingual approaches. Rasa NLU has a nice approach for this, but its too embedded in their codebase for easy usage outside of Rasa/chatbots. Additionally, it made sense to integrate sentence-transformers and Huggingface zero-shot, instead of default word embeddings. Finally, I decided to integrate with spaCy, since training a custom spaCy TextCategorizer seems like a lot of hassle if you want something quick and dirty.", + "github": "davidberenstein1957/classy-classification", + "pip": "classy-classification", + "code_example": [ + "import spacy", + "import classy_classification", + "", + "data = {", + " \"furniture\": [\"This text is about chairs.\",", + " \"Couches, benches and televisions.\",", + " \"I really need to get a new sofa.\"],", + " \"kitchen\": [\"There also exist things like fridges.\",", + " \"I hope to be getting a new stove today.\",", + " \"Do you also have some ovens.\"]", + "}", + "", + "nlp = spacy.load('en_core_web_md')", + "", + "classification_type = \"spacy_few_shot\"", + "if classification_type == \"spacy_few_shot\":", + " nlp.add_pipe(\"text_categorizer\", ", + " config={\"data\": data, \"model\": \"spacy\"}", + " )", + "elif classification_type == \"sentence_transformer_few_shot\":", + " nlp.add_pipe(\"text_categorizer\", ", + " config={\"data\": data, \"model\": \"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2\"}", + " )", + "elif classification_type == \"huggingface_zero_shot\":", + " nlp.add_pipe(\"text_categorizer\", ", + " config={\"data\": list(data.keys()), \"cat_type\": \"zero\", \"model\": \"facebook/bart-large-mnli\"}", + " )", + "", + "print(nlp(\"I am looking for kitchen appliances.\")._.cats)", + "print([doc._.cats for doc in nlp.pipe([\"I am looking for kitchen appliances.\"])])" + ], + "author": "David Berenstein", + "author_links": { + "github": "davidberenstein1957", + "website": "https://www.linkedin.com/in/david-berenstein-1bab11105/" + }, + "category": ["pipeline", "standalone"], + "tags": ["classification", "zero-shot", "few-shot", "sentence-transformers", "huggingface"], + "spacy_version": 3 + }, { "id": "blackstone", "title": "Blackstone", From 7ed7908716094ff41e4d1b2f60479f6b8356d700 Mon Sep 17 00:00:00 2001 From: 
jnphilipp Date: Mon, 7 Mar 2022 16:20:39 +0100 Subject: [PATCH 005/424] Add Upper Sorbian support. (#10432) * Add basic support for Upper Sorbian. * Add tokenizer exceptions and tests. * Update spacy/lang/hsb/examples.py Co-authored-by: Sofie Van Landeghem --- spacy/lang/hsb/__init__.py | 18 ++++++ spacy/lang/hsb/examples.py | 15 +++++ spacy/lang/hsb/lex_attrs.py | 77 ++++++++++++++++++++++++++ spacy/lang/hsb/stop_words.py | 19 +++++++ spacy/lang/hsb/tokenizer_exceptions.py | 18 ++++++ spacy/tests/conftest.py | 5 ++ spacy/tests/lang/hsb/__init__.py | 0 spacy/tests/lang/hsb/test_text.py | 25 +++++++ spacy/tests/lang/hsb/test_tokenizer.py | 32 +++++++++++ 9 files changed, 209 insertions(+) create mode 100644 spacy/lang/hsb/__init__.py create mode 100644 spacy/lang/hsb/examples.py create mode 100644 spacy/lang/hsb/lex_attrs.py create mode 100644 spacy/lang/hsb/stop_words.py create mode 100644 spacy/lang/hsb/tokenizer_exceptions.py create mode 100644 spacy/tests/lang/hsb/__init__.py create mode 100644 spacy/tests/lang/hsb/test_text.py create mode 100644 spacy/tests/lang/hsb/test_tokenizer.py diff --git a/spacy/lang/hsb/__init__.py b/spacy/lang/hsb/__init__.py new file mode 100644 index 000000000..034d82319 --- /dev/null +++ b/spacy/lang/hsb/__init__.py @@ -0,0 +1,18 @@ +from .lex_attrs import LEX_ATTRS +from .stop_words import STOP_WORDS +from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS +from ...language import Language, BaseDefaults + + +class UpperSorbianDefaults(BaseDefaults): + lex_attr_getters = LEX_ATTRS + stop_words = STOP_WORDS + tokenizer_exceptions = TOKENIZER_EXCEPTIONS + + +class UpperSorbian(Language): + lang = "hsb" + Defaults = UpperSorbianDefaults + + +__all__ = ["UpperSorbian"] diff --git a/spacy/lang/hsb/examples.py b/spacy/lang/hsb/examples.py new file mode 100644 index 000000000..0aafd5cee --- /dev/null +++ b/spacy/lang/hsb/examples.py @@ -0,0 +1,15 @@ +""" +Example sentences to test spaCy and its language models. + +>>> from spacy.lang.hsb.examples import sentences +>>> docs = nlp.pipe(sentences) +""" + + +sentences = [ + "To běšo wjelgin raźone a jo se wót luźi derje pśiwzeło. Tak som dožywiła wjelgin", + "Jogo pśewóźowarce stej groniłej, až how w serbskich stronach njama Santa Claus nic pytaś.", + "A ten sobuźěłaśeŕ Statneje biblioteki w Barlinju jo pśimjeł drogotne knigły bźez rukajcowu z nagima rukoma!", + "Take wobchadanje z našym kulturnym derbstwom zewšym njejźo.", + "Wopśimjeśe drugich pśinoskow jo było na wusokem niwowje, ako pśecej." 
+] diff --git a/spacy/lang/hsb/lex_attrs.py b/spacy/lang/hsb/lex_attrs.py new file mode 100644 index 000000000..dfda3e2db --- /dev/null +++ b/spacy/lang/hsb/lex_attrs.py @@ -0,0 +1,77 @@ +from ...attrs import LIKE_NUM + +_num_words = [ + "nul", + "jedyn", "jedna", "jedne", + "dwaj", "dwě", + "tři", "třo", + "štyri", "štyrjo", + "pjeć", + "šěsć", + "sydom", + "wosom", + "dźewjeć", + "dźesać", + "jědnaće", + "dwanaće", + "třinaće", + "štyrnaće", + "pjatnaće", + "šěsnaće", + "sydomnaće", + "wosomnaće", + "dźewjatnaće", + "dwaceći", "třiceći", + "štyrceći", + "pjećdźesat", + "šěsćdźesat", + "sydomdźesat", + "wosomdźesat", + "dźewjećdźesat", + "sto", + "tysac", + "milion", + "miliarda", + "bilion", + "biliarda", + "trilion", + "triliarda", +] + +_ordinal_words = [ + "prěni", "prěnja", "prěnje", + "druhi", "druha", "druhe", + "třeći", "třeća", "třeće", + "štwórty", "štwórta", "štwórte", + "pjaty", "pjata", "pjate", + "šěsty", "šěsta", "šěste", + "sydmy", "sydma", "sydme", + "wosmy", "wosma", "wosme", + "dźewjaty", "dźewjata", "dźewjate", + "dźesaty", "dźesata", "dźesate", + "jědnaty", "jědnata", "jědnate", + "dwanaty", "dwanata", "dwanate" +] + + +def like_num(text): + if text.startswith(("+", "-", "±", "~")): + text = text[1:] + text = text.replace(",", "").replace(".", "") + if text.isdigit(): + return True + if text.count("/") == 1: + num, denom = text.split("/") + if num.isdigit() and denom.isdigit(): + return True + text_lower = text.lower() + if text_lower in _num_words: + return True + # Check ordinal number + if text_lower in _ordinal_words: + return True + return False + + +LEX_ATTRS = {LIKE_NUM: like_num} diff --git a/spacy/lang/hsb/stop_words.py b/spacy/lang/hsb/stop_words.py new file mode 100644 index 000000000..e6fedaf4c --- /dev/null +++ b/spacy/lang/hsb/stop_words.py @@ -0,0 +1,19 @@ +STOP_WORDS = set( + """ +a abo ale ani + +dokelž + +hdyž + +jeli jelizo + +kaž + +pak potom + +tež tohodla + +zo zoby +""".split() +) diff --git a/spacy/lang/hsb/tokenizer_exceptions.py b/spacy/lang/hsb/tokenizer_exceptions.py new file mode 100644 index 000000000..4b9a4f98a --- /dev/null +++ b/spacy/lang/hsb/tokenizer_exceptions.py @@ -0,0 +1,18 @@ +from ..tokenizer_exceptions import BASE_EXCEPTIONS +from ...symbols import ORTH, NORM +from ...util import update_exc + +_exc = dict() +for exc_data in [ + {ORTH: "mil.", NORM: "milion"}, + {ORTH: "wob.", NORM: "wobydler"}, +]: + _exc[exc_data[ORTH]] = [exc_data] + +for orth in [ + "resp.", +]: + _exc[orth] = [{ORTH: orth}] + + +TOKENIZER_EXCEPTIONS = update_exc(BASE_EXCEPTIONS, _exc) diff --git a/spacy/tests/conftest.py b/spacy/tests/conftest.py index f9266cb94..7083fd817 100644 --- a/spacy/tests/conftest.py +++ b/spacy/tests/conftest.py @@ -221,6 +221,11 @@ def ja_tokenizer(): return get_lang_class("ja")().tokenizer +@pytest.fixture(scope="session") +def hsb_tokenizer(): + return get_lang_class("hsb")().tokenizer + + @pytest.fixture(scope="session") def ko_tokenizer(): pytest.importorskip("natto") diff --git a/spacy/tests/lang/hsb/__init__.py b/spacy/tests/lang/hsb/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spacy/tests/lang/hsb/test_text.py b/spacy/tests/lang/hsb/test_text.py new file mode 100644 index 000000000..aaa4984eb --- /dev/null +++ b/spacy/tests/lang/hsb/test_text.py @@ -0,0 +1,25 @@ +import pytest + + +@pytest.mark.parametrize( + "text,match", + [ + ("10", True), + ("1", True), + ("10,000", True), + ("10,00", True), + ("jedne", True), + ("dwanaće", True), + ("milion", True), + ("sto", True), + ("załožene", 
False), + ("wona", False), + ("powšitkownej", False), + (",", False), + ("1/2", True), + ], +) +def test_lex_attrs_like_number(hsb_tokenizer, text, match): + tokens = hsb_tokenizer(text) + assert len(tokens) == 1 + assert tokens[0].like_num == match diff --git a/spacy/tests/lang/hsb/test_tokenizer.py b/spacy/tests/lang/hsb/test_tokenizer.py new file mode 100644 index 000000000..a3ec89ba0 --- /dev/null +++ b/spacy/tests/lang/hsb/test_tokenizer.py @@ -0,0 +1,32 @@ +import pytest + +HSB_BASIC_TOKENIZATION_TESTS = [ + ( + "Hornjoserbšćina wobsteji resp. wobsteješe z wjacorych dialektow, kotrež so zdźěla chětro wot so rozeznawachu.", + [ + "Hornjoserbšćina", + "wobsteji", + "resp.", + "wobsteješe", + "z", + "wjacorych", + "dialektow", + ",", + "kotrež", + "so", + "zdźěla", + "chětro", + "wot", + "so", + "rozeznawachu", + ".", + ], + ), +] + + +@pytest.mark.parametrize("text,expected_tokens", HSB_BASIC_TOKENIZATION_TESTS) +def test_hsb_tokenizer_basic(hsb_tokenizer, text, expected_tokens): + tokens = hsb_tokenizer(text) + token_list = [token.text for token in tokens if not token.is_space] + assert expected_tokens == token_list From 61ba5450ff5de3c1bbbca21169772d1239ee822f Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Tue, 8 Mar 2022 00:56:57 +0900 Subject: [PATCH 006/424] Fix get_matching_ents (#10451) * Fix get_matching_ents Not sure what happened here - the code prior to this commit simply does not work. It's already covered by entity linker tests, which were succeeding in the NEL PR, but couldn't possibly succeed on master. * Fix test Test was indented inside another test and so doesn't seem to have been running properly. --- spacy/tests/pipeline/test_entity_linker.py | 108 ++++++++++----------- spacy/training/example.pyx | 4 +- 2 files changed, 56 insertions(+), 56 deletions(-) diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 7d1382741..af2132d73 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -1009,65 +1009,65 @@ def test_legacy_architectures(name, config): losses = {} nlp.update(train_examples, sgd=optimizer, losses=losses) - @pytest.mark.parametrize("patterns", [ - # perfect case - [{"label": "CHARACTER", "pattern": "Kirby"}], - # typo for false negative - [{"label": "PERSON", "pattern": "Korby"}], - # random stuff for false positive - [{"label": "IS", "pattern": "is"}, {"label": "COLOR", "pattern": "pink"}], - ] - ) - def test_no_gold_ents(patterns): - # test that annotating components work - TRAIN_DATA = [ - ( - "Kirby is pink", - { - "links": {(0, 5): {"Q613241": 1.0}}, - "entities": [(0, 5, "CHARACTER")], - "sent_starts": [1, 0, 0], - }, - ) - ] - nlp = English() - vector_length = 3 - train_examples = [] - for text, annotation in TRAIN_DATA: - doc = nlp(text) - train_examples.append(Example.from_dict(doc, annotation)) +@pytest.mark.parametrize("patterns", [ + # perfect case + [{"label": "CHARACTER", "pattern": "Kirby"}], + # typo for false negative + [{"label": "PERSON", "pattern": "Korby"}], + # random stuff for false positive + [{"label": "IS", "pattern": "is"}, {"label": "COLOR", "pattern": "pink"}], + ] +) +def test_no_gold_ents(patterns): + # test that annotating components work + TRAIN_DATA = [ + ( + "Kirby is pink", + { + "links": {(0, 5): {"Q613241": 1.0}}, + "entities": [(0, 5, "CHARACTER")], + "sent_starts": [1, 0, 0], + }, + ) + ] + nlp = English() + vector_length = 3 + train_examples = [] + for text, annotation in TRAIN_DATA: + doc = nlp(text) + 
train_examples.append(Example.from_dict(doc, annotation))

-        # Create a ruler to mark entities
-        ruler = nlp.add_pipe("entity_ruler")
-        ruler.add_patterns(patterns)
+    # Create a ruler to mark entities
+    ruler = nlp.add_pipe("entity_ruler")
+    ruler.add_patterns(patterns)

-        # Apply ruler to examples. In a real pipeline this would be an annotating component.
-        for eg in train_examples:
-            eg.predicted = ruler(eg.predicted)
+    # Apply ruler to examples. In a real pipeline this would be an annotating component.
+    for eg in train_examples:
+        eg.predicted = ruler(eg.predicted)

-        def create_kb(vocab):
-            # create artificial KB
-            mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
-            mykb.add_entity(entity="Q613241", freq=12, entity_vector=[6, -4, 3])
-            mykb.add_alias("Kirby", ["Q613241"], [0.9])
-            # Placeholder
-            mykb.add_entity(entity="pink", freq=12, entity_vector=[7, 2, -5])
-            mykb.add_alias("pink", ["pink"], [0.9])
-            return mykb
+    def create_kb(vocab):
+        # create artificial KB
+        mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
+        mykb.add_entity(entity="Q613241", freq=12, entity_vector=[6, -4, 3])
+        mykb.add_alias("Kirby", ["Q613241"], [0.9])
+        # Placeholder
+        mykb.add_entity(entity="pink", freq=12, entity_vector=[7, 2, -5])
+        mykb.add_alias("pink", ["pink"], [0.9])
+        return mykb

-        # Create and train the Entity Linker
-        entity_linker = nlp.add_pipe("entity_linker", config={"use_gold_ents": False}, last=True)
-        entity_linker.set_kb(create_kb)
-        assert entity_linker.use_gold_ents == False
+    # Create and train the Entity Linker
+    entity_linker = nlp.add_pipe("entity_linker", config={"use_gold_ents": False}, last=True)
+    entity_linker.set_kb(create_kb)
+    assert entity_linker.use_gold_ents == False

-        optimizer = nlp.initialize(get_examples=lambda: train_examples)
-        for i in range(2):
-            losses = {}
-            nlp.update(train_examples, sgd=optimizer, losses=losses)
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+    for i in range(2):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)

-        # adding additional components that are required for the entity_linker
-        nlp.add_pipe("sentencizer", first=True)
+    # adding additional components that are required for the entity_linker
+    nlp.add_pipe("sentencizer", first=True)

-        # this will run the pipeline on the examples and shouldn't crash
-        results = nlp.evaluate(train_examples)
+    # this will run the pipeline on the examples and shouldn't crash
+    results = nlp.evaluate(train_examples)
diff --git a/spacy/training/example.pyx b/spacy/training/example.pyx
index a2c5e08e9..778dfd12a 100644
--- a/spacy/training/example.pyx
+++ b/spacy/training/example.pyx
@@ -263,11 +263,11 @@ cdef class Example:
         kept. Otherwise only the character indices need to match.
         """
         gold = {}
-        for ent in self.reference:
+        for ent in self.reference.ents:
            gold[(ent.start_char, ent.end_char)] = ent.label
         keep = []
-        for ent in self.predicted:
+        for ent in self.predicted.ents:
            key = (ent.start_char, ent.end_char)
            if key not in gold:
                continue

From 5ca0dbae765c405f3aa74e32ab9e93d5ce752179 Mon Sep 17 00:00:00 2001
From: jnphilipp
Date: Mon, 7 Mar 2022 16:57:14 +0100
Subject: [PATCH 007/424] Add Lower Sorbian support. (#10431)

* Add basic support for Lower Sorbian.

* Add some tests for dsb.
* Update spacy/lang/dsb/examples.py Co-authored-by: Sofie Van Landeghem --- spacy/lang/dsb/__init__.py | 16 ++++++ spacy/lang/dsb/examples.py | 15 +++++ spacy/lang/dsb/lex_attrs.py | 77 ++++++++++++++++++++++++++ spacy/lang/dsb/stop_words.py | 15 +++++ spacy/tests/conftest.py | 5 ++ spacy/tests/lang/dsb/__init__.py | 0 spacy/tests/lang/dsb/test_text.py | 25 +++++++++ spacy/tests/lang/dsb/test_tokenizer.py | 29 ++++++++++ 8 files changed, 182 insertions(+) create mode 100644 spacy/lang/dsb/__init__.py create mode 100644 spacy/lang/dsb/examples.py create mode 100644 spacy/lang/dsb/lex_attrs.py create mode 100644 spacy/lang/dsb/stop_words.py create mode 100644 spacy/tests/lang/dsb/__init__.py create mode 100644 spacy/tests/lang/dsb/test_text.py create mode 100644 spacy/tests/lang/dsb/test_tokenizer.py diff --git a/spacy/lang/dsb/__init__.py b/spacy/lang/dsb/__init__.py new file mode 100644 index 000000000..c66092a0c --- /dev/null +++ b/spacy/lang/dsb/__init__.py @@ -0,0 +1,16 @@ +from .lex_attrs import LEX_ATTRS +from .stop_words import STOP_WORDS +from ...language import Language, BaseDefaults + + +class LowerSorbianDefaults(BaseDefaults): + lex_attr_getters = LEX_ATTRS + stop_words = STOP_WORDS + + +class LowerSorbian(Language): + lang = "dsb" + Defaults = LowerSorbianDefaults + + +__all__ = ["LowerSorbian"] diff --git a/spacy/lang/dsb/examples.py b/spacy/lang/dsb/examples.py new file mode 100644 index 000000000..28b8c41f1 --- /dev/null +++ b/spacy/lang/dsb/examples.py @@ -0,0 +1,15 @@ +""" +Example sentences to test spaCy and its language models. + +>>> from spacy.lang.dsb.examples import sentences +>>> docs = nlp.pipe(sentences) +""" + + +sentences = [ + "Z tym stwori so wuměnjenje a zakład za dalše wobdźěłanje přez analyzu tekstoweje struktury a semantisku anotaciju a z tym tež za tu předstajenu digitalnu online-wersiju.", + "Mi so tu jara derje spodoba.", + "Kotre nowniny chceće měć?", + "Tak ako w slědnem lěśe jo teke lětosa jano doma zapustowaś móžno.", + "Zwóstanjo pótakem hyšći wjele źěła." 
+] diff --git a/spacy/lang/dsb/lex_attrs.py b/spacy/lang/dsb/lex_attrs.py new file mode 100644 index 000000000..75fb2e590 --- /dev/null +++ b/spacy/lang/dsb/lex_attrs.py @@ -0,0 +1,77 @@ +from ...attrs import LIKE_NUM + +_num_words = [ + "nul", + "jaden", "jadna", "jadno", + "dwa", "dwě", + "tśi", "tśo", + "styri", "styrjo", + "pěś", "pěśo", + "šesć", "šesćo", + "sedym", "sedymjo", + "wósym", "wósymjo", + "źewjeś", "źewjeśo", + "źaseś", "źaseśo", + "jadnassćo", + "dwanassćo", + "tśinasćo", + "styrnasćo", + "pěśnasćo", + "šesnasćo", + "sedymnasćo", + "wósymnasćo", + "źewjeśnasćo", + "dwanasćo", "dwaźasća", + "tśiźasća", + "styrźasća", + "pěśźaset", + "šesćźaset", + "sedymźaset", + "wósymźaset", + "źewjeśźaset", + "sto", + "tysac", + "milion", + "miliarda", + "bilion", + "biliarda", + "trilion", + "triliarda", +] + +_ordinal_words = [ + "prědny", "prědna", "prědne", + "drugi", "druga", "druge", + "tśeśi", "tśeśa", "tśeśe", + "stwórty", "stwórta", "stwórte", + "pêty", "pěta", "pête", + "šesty", "šesta", "šeste", + "sedymy", "sedyma", "sedyme", + "wósymy", "wósyma", "wósyme", + "źewjety", "źewjeta", "źewjete", + "źasety", "źaseta", "źasete", + "jadnasty", "jadnasta", "jadnaste", + "dwanasty", "dwanasta", "dwanaste" +] + + +def like_num(text): + if text.startswith(("+", "-", "±", "~")): + text = text[1:] + text = text.replace(",", "").replace(".", "") + if text.isdigit(): + return True + if text.count("/") == 1: + num, denom = text.split("/") + if num.isdigit() and denom.isdigit(): + return True + text_lower = text.lower() + if text_lower in _num_words: + return True + # Check ordinal number + if text_lower in _ordinal_words: + return True + return False + + +LEX_ATTRS = {LIKE_NUM: like_num} diff --git a/spacy/lang/dsb/stop_words.py b/spacy/lang/dsb/stop_words.py new file mode 100644 index 000000000..376e04aa6 --- /dev/null +++ b/spacy/lang/dsb/stop_words.py @@ -0,0 +1,15 @@ +STOP_WORDS = set( + """ +a abo aby ako ale až + +daniž dokulaž + +gaž + +jolic + +pak pótom + +teke togodla +""".split() +) diff --git a/spacy/tests/conftest.py b/spacy/tests/conftest.py index 7083fd817..24474c71e 100644 --- a/spacy/tests/conftest.py +++ b/spacy/tests/conftest.py @@ -99,6 +99,11 @@ def de_vocab(): return get_lang_class("de")().vocab +@pytest.fixture(scope="session") +def dsb_tokenizer(): + return get_lang_class("dsb")().tokenizer + + @pytest.fixture(scope="session") def el_tokenizer(): return get_lang_class("el")().tokenizer diff --git a/spacy/tests/lang/dsb/__init__.py b/spacy/tests/lang/dsb/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spacy/tests/lang/dsb/test_text.py b/spacy/tests/lang/dsb/test_text.py new file mode 100644 index 000000000..40f2c15e0 --- /dev/null +++ b/spacy/tests/lang/dsb/test_text.py @@ -0,0 +1,25 @@ +import pytest + + +@pytest.mark.parametrize( + "text,match", + [ + ("10", True), + ("1", True), + ("10,000", True), + ("10,00", True), + ("jadno", True), + ("dwanassćo", True), + ("milion", True), + ("sto", True), + ("ceła", False), + ("kopica", False), + ("narěcow", False), + (",", False), + ("1/2", True), + ], +) +def test_lex_attrs_like_number(dsb_tokenizer, text, match): + tokens = dsb_tokenizer(text) + assert len(tokens) == 1 + assert tokens[0].like_num == match diff --git a/spacy/tests/lang/dsb/test_tokenizer.py b/spacy/tests/lang/dsb/test_tokenizer.py new file mode 100644 index 000000000..135974fb8 --- /dev/null +++ b/spacy/tests/lang/dsb/test_tokenizer.py @@ -0,0 +1,29 @@ +import pytest + +DSB_BASIC_TOKENIZATION_TESTS = [ + ( + "Ale eksistěrujo mimo 
togo ceła kopica narěcow, ako na pśikład slěpjańska.", + [ + "Ale", + "eksistěrujo", + "mimo", + "togo", + "ceła", + "kopica", + "narěcow", + ",", + "ako", + "na", + "pśikład", + "slěpjańska", + ".", + ], + ), +] + + +@pytest.mark.parametrize("text,expected_tokens", DSB_BASIC_TOKENIZATION_TESTS) +def test_dsb_tokenizer_basic(dsb_tokenizer, text, expected_tokens): + tokens = dsb_tokenizer(text) + token_list = [token.text for token in tokens if not token.is_space] + assert expected_tokens == token_list From b2bbefd0b542fcad527b9badf97fd1c3c69a7bbf Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Mon, 7 Mar 2022 17:03:45 +0100 Subject: [PATCH 008/424] Add Finnish, Korean, and Swedish models and Korean support notes (#10355) * Add Finnish, Korean, and Swedish models to website * Add Korean language support notes --- website/docs/usage/models.md | 47 +++++++++++++++++++++++++++++++++--- website/meta/languages.json | 21 +++++++++++++--- 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/website/docs/usage/models.md b/website/docs/usage/models.md index 3b79c4d0d..f82da44d9 100644 --- a/website/docs/usage/models.md +++ b/website/docs/usage/models.md @@ -259,6 +259,45 @@ used for training the current [Japanese pipelines](/models/ja). +### Korean language support {#korean} + +> #### mecab-ko tokenizer +> +> ```python +> nlp = spacy.blank("ko") +> ``` + +The default MeCab-based Korean tokenizer requires: + +- [mecab-ko](https://bitbucket.org/eunjeon/mecab-ko/src/master/README.md) +- [mecab-ko-dic](https://bitbucket.org/eunjeon/mecab-ko-dic) +- [natto-py](https://github.com/buruzaemon/natto-py) + +For some Korean datasets and tasks, the +[rule-based tokenizer](/usage/linguistic-features#tokenization) is better-suited +than MeCab. To configure a Korean pipeline with the rule-based tokenizer: + +> #### Rule-based tokenizer +> +> ```python +> config = {"nlp": {"tokenizer": {"@tokenizers": "spacy.Tokenizer.v1"}}} +> nlp = spacy.blank("ko", config=config) +> ``` + +```ini +### config.cfg +[nlp] +lang = "ko" +tokenizer = {"@tokenizers" = "spacy.Tokenizer.v1"} +``` + + + +The [Korean trained pipelines](/models/ko) use the rule-based tokenizer, so no +additional dependencies are required. + + + ## Installing and using trained pipelines {#download} The easiest way to download a trained pipeline is via spaCy's @@ -417,10 +456,10 @@ doc = nlp("This is a sentence.") You can use the [`info`](/api/cli#info) command or -[`spacy.info()`](/api/top-level#spacy.info) method to print a pipeline -package's meta data before loading it. Each `Language` object with a loaded -pipeline also exposes the pipeline's meta data as the attribute `meta`. For -example, `nlp.meta['version']` will return the package version. +[`spacy.info()`](/api/top-level#spacy.info) method to print a pipeline package's +meta data before loading it. Each `Language` object with a loaded pipeline also +exposes the pipeline's meta data as the attribute `meta`. For example, +`nlp.meta['version']` will return the package version. 
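
As a quick aside on the Korean section above: the documented config override can be tried without installing any of the MeCab dependencies, since the rule-based tokenizer ships with spaCy itself. A minimal sketch (not part of this patch; it assumes spaCy v3 with the bundled `ko` language data, and reuses the example sentence from `languages.json` below):

```python
import spacy

# Swap the default MeCab-based Korean tokenizer for the rule-based one,
# mirroring the config snippet documented above. The rule-based tokenizer
# has no external dependencies, so this runs on a bare spaCy install.
config = {"nlp": {"tokenizer": {"@tokenizers": "spacy.Tokenizer.v1"}}}
nlp = spacy.blank("ko", config=config)

doc = nlp("이것은 문장입니다.")
print([token.text for token in doc])  # e.g. ['이것은', '문장입니다', '.']
```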
diff --git a/website/meta/languages.json b/website/meta/languages.json index a7dda6482..1c4379b6d 100644 --- a/website/meta/languages.json +++ b/website/meta/languages.json @@ -114,7 +114,12 @@ { "code": "fi", "name": "Finnish", - "has_examples": true + "has_examples": true, + "models": [ + "fi_core_news_sm", + "fi_core_news_md", + "fi_core_news_lg" + ] }, { "code": "fr", @@ -227,7 +232,12 @@ } ], "example": "이것은 문장입니다.", - "has_examples": true + "has_examples": true, + "models": [ + "ko_core_news_sm", + "ko_core_news_md", + "ko_core_news_lg" + ] }, { "code": "ky", @@ -388,7 +398,12 @@ { "code": "sv", "name": "Swedish", - "has_examples": true + "has_examples": true, + "models": [ + "sv_core_news_sm", + "sv_core_news_md", + "sv_core_news_lg" + ] }, { "code": "ta", From 60520d86693699c1221a4414a133f76ffb9601b0 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Tue, 8 Mar 2022 13:51:11 +0100 Subject: [PATCH 009/424] Fix types in API docs for moves in parser and ner (#10464) --- website/docs/api/dependencyparser.md | 2 +- website/docs/api/entityrecognizer.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/website/docs/api/dependencyparser.md b/website/docs/api/dependencyparser.md index 118cdc611..103e0826e 100644 --- a/website/docs/api/dependencyparser.md +++ b/website/docs/api/dependencyparser.md @@ -100,7 +100,7 @@ shortcut for this and instantiate the component using its string name and | `vocab` | The shared vocabulary. ~~Vocab~~ | | `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. ~~Model[List[Doc], List[Floats2d]]~~ | | `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ | -| `moves` | A list of transition names. Inferred from the data if not provided. ~~Optional[List[str]]~~ | +| `moves` | A list of transition names. Inferred from the data if not provided. ~~Optional[TransitionSystem]~~ | | _keyword-only_ | | | `update_with_oracle_cut_size` | During training, cut long sequences into shorter segments by creating intermediate states based on the gold-standard history. The model is not very sensitive to this parameter, so you usually won't need to change it. Defaults to `100`. ~~int~~ | | `learn_tokens` | Whether to learn to merge subtokens that are split relative to the gold standard. Experimental. Defaults to `False`. ~~bool~~ | diff --git a/website/docs/api/entityrecognizer.md b/website/docs/api/entityrecognizer.md index 14b6fece4..7c153f064 100644 --- a/website/docs/api/entityrecognizer.md +++ b/website/docs/api/entityrecognizer.md @@ -62,7 +62,7 @@ architectures and their arguments and hyperparameters. | Setting | Description | | ----------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `moves` | A list of transition names. Inferred from the data if not provided. Defaults to `None`. ~~Optional[List[str]]~~ | +| `moves` | A list of transition names. Inferred from the data if not provided. Defaults to `None`. ~~Optional[TransitionSystem]~~ | | `update_with_oracle_cut_size` | During training, cut long sequences into shorter segments by creating intermediate states based on the gold-standard history. The model is not very sensitive to this parameter, so you usually won't need to change it. Defaults to `100`. 
~~int~~ | | `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [TransitionBasedParser](/api/architectures#TransitionBasedParser). ~~Model[List[Doc], List[Floats2d]]~~ | | `incorrect_spans_key` | This key refers to a `SpanGroup` in `doc.spans` that specifies incorrect spans. The NER will learn not to predict (exactly) those spans. Defaults to `None`. ~~Optional[str]~~ | @@ -98,7 +98,7 @@ shortcut for this and instantiate the component using its string name and | `vocab` | The shared vocabulary. ~~Vocab~~ | | `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. ~~Model[List[Doc], List[Floats2d]]~~ | | `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ | -| `moves` | A list of transition names. Inferred from the data if set to `None`, which is the default. ~~Optional[List[str]]~~ | +| `moves` | A list of transition names. Inferred from the data if set to `None`, which is the default. ~~Optional[TransitionSystem]~~ | | _keyword-only_ | | | `update_with_oracle_cut_size` | During training, cut long sequences into shorter segments by creating intermediate states based on the gold-standard history. The model is not very sensitive to this parameter, so you usually won't need to change it. Defaults to `100`. ~~int~~ | | `incorrect_spans_key` | Identifies spans that are known to be incorrect entity annotations. The incorrect entity annotations can be stored in the span group in [`Doc.spans`](/api/doc#spans), under this key. Defaults to `None`. ~~Optional[str]~~ | From 191e8b31fa75f60b32f9e4779fe629b3c31e7c5e Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Tue, 8 Mar 2022 14:28:46 +0100 Subject: [PATCH 010/424] Remove English tokenizer exception May. (#10463) --- spacy/lang/en/tokenizer_exceptions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/spacy/lang/en/tokenizer_exceptions.py b/spacy/lang/en/tokenizer_exceptions.py index 55b544e42..2c20b8c27 100644 --- a/spacy/lang/en/tokenizer_exceptions.py +++ b/spacy/lang/en/tokenizer_exceptions.py @@ -447,7 +447,6 @@ for exc_data in [ {ORTH: "La.", NORM: "Louisiana"}, {ORTH: "Mar.", NORM: "March"}, {ORTH: "Mass.", NORM: "Massachusetts"}, - {ORTH: "May.", NORM: "May"}, {ORTH: "Mich.", NORM: "Michigan"}, {ORTH: "Minn.", NORM: "Minnesota"}, {ORTH: "Miss.", NORM: "Mississippi"}, From 01ec6349eab7fd1d426a29bd6b9546826fb38bfa Mon Sep 17 00:00:00 2001 From: Peter Baumgartner <5107405+pmbaumgartner@users.noreply.github.com> Date: Tue, 8 Mar 2022 10:04:10 -0500 Subject: [PATCH 011/424] Add `path.mkdir` to custom component examples of `to_disk` (#10348) * add `path.mkdir` to examples * add ensure_path + mkdir * update highlights --- website/docs/usage/processing-pipelines.md | 6 +++++- website/docs/usage/saving-loading.md | 12 +++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/website/docs/usage/processing-pipelines.md b/website/docs/usage/processing-pipelines.md index 11fd1459d..9e6ee54df 100644 --- a/website/docs/usage/processing-pipelines.md +++ b/website/docs/usage/processing-pipelines.md @@ -1081,13 +1081,17 @@ on [serialization methods](/usage/saving-loading/#serialization-methods). > directory. ```python -### Custom serialization methods {highlight="6-7,9-11"} +### Custom serialization methods {highlight="7-11,13-15"} import srsly +from spacy.util import ensure_path class AcronymComponent: # other methods here... 
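    # Note (explanatory comment, not part of the diff): nlp.to_disk()
    # creates only the top-level pipeline directory, while each component's
    # to_disk() receives a subdirectory (path / component_name) that may not
    # exist yet. That is why the ensure_path/mkdir guard below is needed.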
def to_disk(self, path, exclude=tuple()): + path = ensure_path(path) + if not path.exists(): + path.mkdir() srsly.write_json(path / "data.json", self.data) def from_disk(self, path, exclude=tuple()): diff --git a/website/docs/usage/saving-loading.md b/website/docs/usage/saving-loading.md index 9dad077e7..af140e7a7 100644 --- a/website/docs/usage/saving-loading.md +++ b/website/docs/usage/saving-loading.md @@ -202,7 +202,9 @@ the data to and from a JSON file. > rules _with_ the component data. ```python -### {highlight="14-18,20-25"} +### {highlight="16-23,25-30"} +from spacy.util import ensure_path + @Language.factory("my_component") class CustomComponent: def __init__(self): @@ -218,6 +220,9 @@ class CustomComponent: def to_disk(self, path, exclude=tuple()): # This will receive the directory path + /my_component + path = ensure_path(path) + if not path.exists(): + path.mkdir() data_path = path / "data.json" with data_path.open("w", encoding="utf8") as f: f.write(json.dumps(self.data)) @@ -467,7 +472,12 @@ pipeline package. When you save out a pipeline using `nlp.to_disk` and the component exposes a `to_disk` method, it will be called with the disk path. ```python +from spacy.util import ensure_path + def to_disk(self, path, exclude=tuple()): + path = ensure_path(path) + if not path.exists(): + path.mkdir() snek_path = path / "snek.txt" with snek_path.open("w", encoding="utf8") as snek_file: snek_file.write(self.snek) From 297dd82c86372c7aa0a181e55dc72512718aafe8 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Fri, 11 Mar 2022 10:50:47 +0100 Subject: [PATCH 012/424] Fix initial special cases for Tokenizer.explain (#10460) Add the missing initial check for special cases to `Tokenizer.explain` to align with `Tokenizer._tokenize_affixes`. --- spacy/tests/tokenizer/test_tokenizer.py | 13 +++++++++++ spacy/tokenizer.pyx | 4 ++++ website/docs/usage/linguistic-features.md | 28 ++++++++++++++--------- 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/spacy/tests/tokenizer/test_tokenizer.py b/spacy/tests/tokenizer/test_tokenizer.py index a7270cb1e..ed11508b4 100644 --- a/spacy/tests/tokenizer/test_tokenizer.py +++ b/spacy/tests/tokenizer/test_tokenizer.py @@ -521,3 +521,16 @@ def test_tokenizer_infix_prefix(en_vocab): assert tokens == ["±10", "%"] explain_tokens = [t[1] for t in tokenizer.explain("±10%")] assert tokens == explain_tokens + + +def test_tokenizer_initial_special_case_explain(en_vocab): + tokenizer = Tokenizer( + en_vocab, + token_match=re.compile("^id$").match, + rules={ + "id": [{"ORTH": "i"}, {"ORTH": "d"}], + } + ) + tokens = [t.text for t in tokenizer("id")] + explain_tokens = [t[1] for t in tokenizer.explain("id")] + assert tokens == explain_tokens diff --git a/spacy/tokenizer.pyx b/spacy/tokenizer.pyx index 91f228032..ac55a61f3 100644 --- a/spacy/tokenizer.pyx +++ b/spacy/tokenizer.pyx @@ -643,6 +643,10 @@ cdef class Tokenizer: for substring in text.split(): suffixes = [] while substring: + if substring in special_cases: + tokens.extend(("SPECIAL-" + str(i + 1), self.vocab.strings[e[ORTH]]) for i, e in enumerate(special_cases[substring])) + substring = '' + continue while prefix_search(substring) or suffix_search(substring): if token_match(substring): tokens.append(("TOKEN_MATCH", substring)) diff --git a/website/docs/usage/linguistic-features.md b/website/docs/usage/linguistic-features.md index f8baf5588..c3f25565a 100644 --- a/website/docs/usage/linguistic-features.md +++ b/website/docs/usage/linguistic-features.md @@ -799,6 +799,10 @@ def 
tokenizer_pseudo_code( for substring in text.split(): suffixes = [] while substring: + if substring in special_cases: + tokens.extend(special_cases[substring]) + substring = "" + continue while prefix_search(substring) or suffix_search(substring): if token_match(substring): tokens.append(substring) @@ -851,20 +855,22 @@ def tokenizer_pseudo_code( The algorithm can be summarized as follows: 1. Iterate over space-separated substrings. -2. Look for a token match. If there is a match, stop processing and keep this - token. -3. Check whether we have an explicitly defined special case for this substring. +2. Check whether we have an explicitly defined special case for this substring. If we do, use it. -4. Otherwise, try to consume one prefix. If we consumed a prefix, go back to #2, +3. Look for a token match. If there is a match, stop processing and keep this + token. +4. Check whether we have an explicitly defined special case for this substring. + If we do, use it. +5. Otherwise, try to consume one prefix. If we consumed a prefix, go back to #3, so that the token match and special cases always get priority. -5. If we didn't consume a prefix, try to consume a suffix and then go back to - #2. -6. If we can't consume a prefix or a suffix, look for a URL match. -7. If there's no URL match, then look for a special case. -8. Look for "infixes" – stuff like hyphens etc. and split the substring into +6. If we didn't consume a prefix, try to consume a suffix and then go back to + #3. +7. If we can't consume a prefix or a suffix, look for a URL match. +8. If there's no URL match, then look for a special case. +9. Look for "infixes" – stuff like hyphens etc. and split the substring into tokens on all infixes. -9. Once we can't consume any more of the string, handle it as a single token. -10. Make a final pass over the text to check for special cases that include +10. Once we can't consume any more of the string, handle it as a single token. +11. Make a final pass over the text to check for special cases that include spaces or that were missed due to the incremental processing of affixes. From 1bbf23207487da4463e8de96efdb2145b408823e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 11 Mar 2022 12:20:23 +0100 Subject: [PATCH 013/424] Auto-format code with black (#10479) * Auto-format code with black * Update spacy/lang/hsb/lex_attrs.py Co-authored-by: explosion-bot Co-authored-by: Adriane Boyd --- spacy/lang/dsb/examples.py | 2 +- spacy/lang/dsb/lex_attrs.py | 82 ++++++++++++++++------ spacy/lang/hsb/examples.py | 2 +- spacy/lang/hsb/lex_attrs.py | 63 ++++++++++++----- spacy/tests/pipeline/test_entity_linker.py | 24 ++++--- 5 files changed, 121 insertions(+), 52 deletions(-) diff --git a/spacy/lang/dsb/examples.py b/spacy/lang/dsb/examples.py index 28b8c41f1..6e9143826 100644 --- a/spacy/lang/dsb/examples.py +++ b/spacy/lang/dsb/examples.py @@ -11,5 +11,5 @@ sentences = [ "Mi so tu jara derje spodoba.", "Kotre nowniny chceće měć?", "Tak ako w slědnem lěśe jo teke lětosa jano doma zapustowaś móžno.", - "Zwóstanjo pótakem hyšći wjele źěła." 
+ "Zwóstanjo pótakem hyšći wjele źěła.", ] diff --git a/spacy/lang/dsb/lex_attrs.py b/spacy/lang/dsb/lex_attrs.py index 75fb2e590..367b3afb8 100644 --- a/spacy/lang/dsb/lex_attrs.py +++ b/spacy/lang/dsb/lex_attrs.py @@ -2,16 +2,27 @@ from ...attrs import LIKE_NUM _num_words = [ "nul", - "jaden", "jadna", "jadno", - "dwa", "dwě", - "tśi", "tśo", - "styri", "styrjo", - "pěś", "pěśo", - "šesć", "šesćo", - "sedym", "sedymjo", - "wósym", "wósymjo", - "źewjeś", "źewjeśo", - "źaseś", "źaseśo", + "jaden", + "jadna", + "jadno", + "dwa", + "dwě", + "tśi", + "tśo", + "styri", + "styrjo", + "pěś", + "pěśo", + "šesć", + "šesćo", + "sedym", + "sedymjo", + "wósym", + "wósymjo", + "źewjeś", + "źewjeśo", + "źaseś", + "źaseśo", "jadnassćo", "dwanassćo", "tśinasćo", @@ -21,7 +32,8 @@ _num_words = [ "sedymnasćo", "wósymnasćo", "źewjeśnasćo", - "dwanasćo", "dwaźasća", + "dwanasćo", + "dwaźasća", "tśiźasća", "styrźasća", "pěśźaset", @@ -40,18 +52,42 @@ _num_words = [ ] _ordinal_words = [ - "prědny", "prědna", "prědne", - "drugi", "druga", "druge", - "tśeśi", "tśeśa", "tśeśe", - "stwórty", "stwórta", "stwórte", - "pêty", "pěta", "pête", - "šesty", "šesta", "šeste", - "sedymy", "sedyma", "sedyme", - "wósymy", "wósyma", "wósyme", - "źewjety", "źewjeta", "źewjete", - "źasety", "źaseta", "źasete", - "jadnasty", "jadnasta", "jadnaste", - "dwanasty", "dwanasta", "dwanaste" + "prědny", + "prědna", + "prědne", + "drugi", + "druga", + "druge", + "tśeśi", + "tśeśa", + "tśeśe", + "stwórty", + "stwórta", + "stwórte", + "pêty", + "pěta", + "pête", + "šesty", + "šesta", + "šeste", + "sedymy", + "sedyma", + "sedyme", + "wósymy", + "wósyma", + "wósyme", + "źewjety", + "źewjeta", + "źewjete", + "źasety", + "źaseta", + "źasete", + "jadnasty", + "jadnasta", + "jadnaste", + "dwanasty", + "dwanasta", + "dwanaste", ] diff --git a/spacy/lang/hsb/examples.py b/spacy/lang/hsb/examples.py index 0aafd5cee..21f6f7584 100644 --- a/spacy/lang/hsb/examples.py +++ b/spacy/lang/hsb/examples.py @@ -11,5 +11,5 @@ sentences = [ "Jogo pśewóźowarce stej groniłej, až how w serbskich stronach njama Santa Claus nic pytaś.", "A ten sobuźěłaśeŕ Statneje biblioteki w Barlinju jo pśimjeł drogotne knigły bźez rukajcowu z nagima rukoma!", "Take wobchadanje z našym kulturnym derbstwom zewšym njejźo.", - "Wopśimjeśe drugich pśinoskow jo było na wusokem niwowje, ako pśecej." 
+ "Wopśimjeśe drugich pśinoskow jo było na wusokem niwowje, ako pśecej.", ] diff --git a/spacy/lang/hsb/lex_attrs.py b/spacy/lang/hsb/lex_attrs.py index dfda3e2db..5f300a73d 100644 --- a/spacy/lang/hsb/lex_attrs.py +++ b/spacy/lang/hsb/lex_attrs.py @@ -2,10 +2,15 @@ from ...attrs import LIKE_NUM _num_words = [ "nul", - "jedyn", "jedna", "jedne", - "dwaj", "dwě", - "tři", "třo", - "štyri", "štyrjo", + "jedyn", + "jedna", + "jedne", + "dwaj", + "dwě", + "tři", + "třo", + "štyri", + "štyrjo", "pjeć", "šěsć", "sydom", @@ -21,7 +26,7 @@ _num_words = [ "sydomnaće", "wosomnaće", "dźewjatnaće", - "dwaceći" + "dwaceći", "třiceći", "štyrceći", "pjećdźesat", @@ -40,18 +45,42 @@ _num_words = [ ] _ordinal_words = [ - "prěni", "prěnja", "prěnje", - "druhi", "druha", "druhe", - "třeći", "třeća", "třeće", - "štwórty", "štwórta", "štwórte", - "pjaty", "pjata", "pjate", - "šěsty", "šěsta", "šěste", - "sydmy", "sydma", "sydme", - "wosmy", "wosma", "wosme", - "dźewjaty", "dźewjata", "dźewjate", - "dźesaty", "dźesata", "dźesate", - "jědnaty", "jědnata", "jědnate", - "dwanaty", "dwanata", "dwanate" + "prěni", + "prěnja", + "prěnje", + "druhi", + "druha", + "druhe", + "třeći", + "třeća", + "třeće", + "štwórty", + "štwórta", + "štwórte", + "pjaty", + "pjata", + "pjate", + "šěsty", + "šěsta", + "šěste", + "sydmy", + "sydma", + "sydme", + "wosmy", + "wosma", + "wosme", + "dźewjaty", + "dźewjata", + "dźewjate", + "dźesaty", + "dźesata", + "dźesate", + "jědnaty", + "jědnata", + "jědnate", + "dwanaty", + "dwanata", + "dwanate", ] diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index af2132d73..83d5bf0e2 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -1009,14 +1009,17 @@ def test_legacy_architectures(name, config): losses = {} nlp.update(train_examples, sgd=optimizer, losses=losses) -@pytest.mark.parametrize("patterns", [ - # perfect case - [{"label": "CHARACTER", "pattern": "Kirby"}], - # typo for false negative - [{"label": "PERSON", "pattern": "Korby"}], - # random stuff for false positive - [{"label": "IS", "pattern": "is"}, {"label": "COLOR", "pattern": "pink"}], - ] + +@pytest.mark.parametrize( + "patterns", + [ + # perfect case + [{"label": "CHARACTER", "pattern": "Kirby"}], + # typo for false negative + [{"label": "PERSON", "pattern": "Korby"}], + # random stuff for false positive + [{"label": "IS", "pattern": "is"}, {"label": "COLOR", "pattern": "pink"}], + ], ) def test_no_gold_ents(patterns): # test that annotating components work @@ -1055,9 +1058,10 @@ def test_no_gold_ents(patterns): mykb.add_alias("pink", ["pink"], [0.9]) return mykb - # Create and train the Entity Linker - entity_linker = nlp.add_pipe("entity_linker", config={"use_gold_ents": False}, last=True) + entity_linker = nlp.add_pipe( + "entity_linker", config={"use_gold_ents": False}, last=True + ) entity_linker.set_kb(create_kb) assert entity_linker.use_gold_ents == False From 6af6c2e86cc7b08573b261563786bd1ab87d45e9 Mon Sep 17 00:00:00 2001 From: Lj Miranda <12949683+ljvmiranda921@users.noreply.github.com> Date: Mon, 14 Mar 2022 16:41:31 +0800 Subject: [PATCH 014/424] Add a note to the dev docs on mypy (#10485) --- extra/DEVELOPER_DOCS/Code Conventions.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/extra/DEVELOPER_DOCS/Code Conventions.md b/extra/DEVELOPER_DOCS/Code Conventions.md index eba466c46..37cd8ff27 100644 --- a/extra/DEVELOPER_DOCS/Code Conventions.md +++ b/extra/DEVELOPER_DOCS/Code 
Conventions.md @@ -137,7 +137,7 @@ If any of the TODOs you've added are important and should be fixed soon, you sho ## Type hints -We use Python type hints across the `.py` files wherever possible. This makes it easy to understand what a function expects and returns, and modern editors will be able to show this information to you when you call an annotated function. Type hints are not currently used in the `.pyx` (Cython) code, except for definitions of registered functions and component factories, where they're used for config validation. +We use Python type hints across the `.py` files wherever possible. This makes it easy to understand what a function expects and returns, and modern editors will be able to show this information to you when you call an annotated function. Type hints are not currently used in the `.pyx` (Cython) code, except for definitions of registered functions and component factories, where they're used for config validation. Ideally when developing, run `mypy spacy` on the code base to inspect any issues. If possible, you should always use the more descriptive type hints like `List[str]` or even `List[Any]` instead of only `list`. We also annotate arguments and return types of `Callable` – although, you can simplify this if the type otherwise gets too verbose (e.g. functions that return factories to create callbacks). Remember that `Callable` takes two values: a **list** of the argument type(s) in order, and the return values. @@ -155,6 +155,13 @@ def create_callback(some_arg: bool) -> Callable[[str, int], List[str]]: return callback ``` +For typing variables, we prefer the explicit format. + +```diff +- var = value # type: Type ++ var: Type = value +``` + For model architectures, Thinc also provides a collection of [custom types](https://thinc.ai/docs/api-types), including more specific types for arrays and model inputs/outputs. Even outside of static type checking, using these types will make the code a lot easier to read and follow, since it's always clear what array types are expected (and what might go wrong if the output is different from the expected type). 
```python From 23bc93d3d286ca050ae18a9e120331d94454229d Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Mon, 14 Mar 2022 15:17:22 +0100 Subject: [PATCH 015/424] limit pytest to <7.1 (#10488) * limit pytest to <7.1 * 7.1.0 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b8970f686..a034dec27 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,7 +26,7 @@ typing_extensions>=3.7.4.1,<4.0.0.0; python_version < "3.8" # Development dependencies pre-commit>=2.13.0 cython>=0.25,<3.0 -pytest>=5.2.0 +pytest>=5.2.0,<7.1.0 pytest-timeout>=1.3.0,<2.0.0 mock>=2.0.0,<3.0.0 flake8>=3.8.0,<3.10.0 From b68bf43f5bf07b78c062777f35240f031374fe00 Mon Sep 17 00:00:00 2001 From: Edward <43848523+thomashacker@users.noreply.github.com> Date: Mon, 14 Mar 2022 15:47:57 +0100 Subject: [PATCH 016/424] Add spans to doc.to_json (#10073) * Add spans to to_json * adjustments to_json * Change docstring * change doc key naming * Update spacy/tokens/doc.pyx Co-authored-by: Adriane Boyd Co-authored-by: Adriane Boyd --- spacy/tests/doc/test_to_json.py | 12 +++++++++++- spacy/tokens/doc.pyx | 11 ++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/spacy/tests/doc/test_to_json.py b/spacy/tests/doc/test_to_json.py index 9ebee6c88..202281654 100644 --- a/spacy/tests/doc/test_to_json.py +++ b/spacy/tests/doc/test_to_json.py @@ -1,5 +1,5 @@ import pytest -from spacy.tokens import Doc +from spacy.tokens import Doc, Span @pytest.fixture() @@ -60,3 +60,13 @@ def test_doc_to_json_underscore_error_serialize(doc): Doc.set_extension("json_test4", method=lambda doc: doc.text) with pytest.raises(ValueError): doc.to_json(underscore=["json_test4"]) + + +def test_doc_to_json_span(doc): + """Test that Doc.to_json() includes spans""" + doc.spans["test"] = [Span(doc, 0, 2, "test"), Span(doc, 0, 1, "test")] + json_doc = doc.to_json() + assert "spans" in json_doc + assert len(json_doc["spans"]) == 1 + assert len(json_doc["spans"]["test"]) == 2 + assert json_doc["spans"]["test"][0]["start"] == 0 diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index d33764ac9..1a48705fd 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -1457,7 +1457,7 @@ cdef class Doc: underscore (list): Optional list of string names of custom doc._. attributes. Attribute values need to be JSON-serializable. Values will be added to an "_" key in the data, e.g. "_": {"foo": "bar"}. - RETURNS (dict): The data in spaCy's JSON format. + RETURNS (dict): The data in JSON format. 
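            (Illustrative note, not part of the diff: with this change, span
            groups from `doc.spans` are serialized too, under a "spans" key,
            with each span given as a dict of "start" and "end" character
            offsets plus its "label" and "kb_id"; see the code added below.)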
""" data = {"text": self.text} if self.has_annotation("ENT_IOB"): @@ -1486,6 +1486,15 @@ cdef class Doc: token_data["dep"] = token.dep_ token_data["head"] = token.head.i data["tokens"].append(token_data) + + if self.spans: + data["spans"] = {} + for span_group in self.spans: + data["spans"][span_group] = [] + for span in self.spans[span_group]: + span_data = {"start": span.start_char, "end": span.end_char, "label": span.label_, "kb_id": span.kb_id_} + data["spans"][span_group].append(span_data) + if underscore: data["_"] = {} for attr in underscore: From 2eef47dd26a5acbc3f667a2bc3b1ddf16f2d1b07 Mon Sep 17 00:00:00 2001 From: Edward <43848523+thomashacker@users.noreply.github.com> Date: Mon, 14 Mar 2022 16:46:58 +0100 Subject: [PATCH 017/424] Save span candidates produced by spancat suggesters (#10413) * Add save_candidates attribute * Change spancat api * Add unit test * reimplement method to produce a list of doc * Add method to docs * Add new version tag * Add intended use to docstring * prettier formatting --- spacy/pipeline/spancat.py | 18 ++++++++++++++++++ spacy/tests/pipeline/test_spancat.py | 22 ++++++++++++++++++++++ website/docs/api/spancategorizer.md | 18 ++++++++++++++++++ 3 files changed, 58 insertions(+) diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index 3759466d1..0a6138fbc 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -272,6 +272,24 @@ class SpanCategorizer(TrainablePipe): scores = self.model.predict((docs, indices)) # type: ignore return indices, scores + def set_candidates( + self, docs: Iterable[Doc], *, candidates_key: str = "candidates" + ) -> None: + """Use the spancat suggester to add a list of span candidates to a list of docs. + This method is intended to be used for debugging purposes. + + docs (Iterable[Doc]): The documents to modify. + candidates_key (str): Key of the Doc.spans dict to save the candidate spans under. + + DOCS: https://spacy.io/api/spancategorizer#set_candidates + """ + suggester_output = self.suggester(docs, ops=self.model.ops) + + for candidates, doc in zip(suggester_output, docs): # type: ignore + doc.spans[candidates_key] = [] + for index in candidates.dataXd: + doc.spans[candidates_key].append(doc[index[0] : index[1]]) + def set_annotations(self, docs: Iterable[Doc], indices_scores) -> None: """Modify a batch of Doc objects, using pre-computed scores. 
diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py index 8060bc621..15256a763 100644 --- a/spacy/tests/pipeline/test_spancat.py +++ b/spacy/tests/pipeline/test_spancat.py @@ -397,3 +397,25 @@ def test_zero_suggestions(): assert set(spancat.labels) == {"LOC", "PERSON"} nlp.update(train_examples, sgd=optimizer) + + +def test_set_candidates(): + nlp = Language() + spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) + train_examples = make_examples(nlp) + nlp.initialize(get_examples=lambda: train_examples) + texts = [ + "Just a sentence.", + "I like London and Berlin", + "I like Berlin", + "I eat ham.", + ] + + docs = [nlp(text) for text in texts] + spancat.set_candidates(docs) + + assert len(docs) == len(texts) + assert type(docs[0].spans["candidates"]) == SpanGroup + assert len(docs[0].spans["candidates"]) == 9 + assert docs[0].spans["candidates"][0].text == "Just" + assert docs[0].spans["candidates"][4].text == "Just a" diff --git a/website/docs/api/spancategorizer.md b/website/docs/api/spancategorizer.md index 26fcaefdf..fc666aaf7 100644 --- a/website/docs/api/spancategorizer.md +++ b/website/docs/api/spancategorizer.md @@ -239,6 +239,24 @@ Delegates to [`predict`](/api/spancategorizer#predict) and | `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ | | **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ | +## SpanCategorizer.set_candidates {#set_candidates tag="method", new="3.3"} + +Use the suggester to add a list of [`Span`](/api/span) candidates to a list of +[`Doc`](/api/doc) objects. This method is intended to be used for debugging +purposes. + +> #### Example +> +> ```python +> spancat = nlp.add_pipe("spancat") +> spancat.set_candidates(docs, "candidates") +> ``` + +| Name | Description | +| ---------------- | -------------------------------------------------------------------- | +| `docs` | The documents to modify. ~~Iterable[Doc]~~ | +| `candidates_key` | Key of the Doc.spans dict to save the candidate spans under. ~~str~~ | + ## SpanCategorizer.get_loss {#get_loss tag="method"} Find the loss and gradient of loss for the batch of documents and their From 0dc454ba9577262ba23279e66f5ea384dd6677fb Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Tue, 15 Mar 2022 09:10:47 +0100 Subject: [PATCH 018/424] Update docs for Vocab.get_vector (#10486) * Update docs for Vocab.get_vector * Clarify description of 0-vector dimensions --- spacy/vocab.pyx | 5 +++-- website/docs/api/vocab.md | 9 +++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx index badd291ed..58036fffa 100644 --- a/spacy/vocab.pyx +++ b/spacy/vocab.pyx @@ -354,8 +354,9 @@ cdef class Vocab: def get_vector(self, orth): """Retrieve a vector for a word in the vocabulary. Words can be looked - up by string or int ID. If no vectors data is loaded, ValueError is - raised. + up by string or int ID. If the current vectors do not contain an entry + for the word, a 0-vector with the same number of dimensions as the + current vectors is returned. orth (int / unicode): The hash value of a word, or its unicode string. RETURNS (numpy.ndarray or cupy.ndarray): A word vector. Size diff --git a/website/docs/api/vocab.md b/website/docs/api/vocab.md index c0a269d95..4698c68c3 100644 --- a/website/docs/api/vocab.md +++ b/website/docs/api/vocab.md @@ -168,22 +168,19 @@ cosines are calculated in minibatches to reduce memory usage. 
## Vocab.get_vector {#get_vector tag="method" new="2"} Retrieve a vector for a word in the vocabulary. Words can be looked up by string -or hash value. If no vectors data is loaded, a `ValueError` is raised. If `minn` -is defined, then the resulting vector uses [FastText](https://fasttext.cc/)'s -subword features by average over n-grams of `orth` (introduced in spaCy `v2.1`). +or hash value. If the current vectors do not contain an entry for the word, a +0-vector with the same number of dimensions +([`Vocab.vectors_length`](#attributes)) as the current vectors is returned. > #### Example > > ```python > nlp.vocab.get_vector("apple") -> nlp.vocab.get_vector("apple", minn=1, maxn=5) > ``` | Name | Description | | ----------------------------------- | ---------------------------------------------------------------------------------------------------------------------- | | `orth` | The hash value of a word, or its unicode string. ~~Union[int, str]~~ | -| `minn` 2.1 | Minimum n-gram length used for FastText's n-gram computation. Defaults to the length of `orth`. ~~int~~ | -| `maxn` 2.1 | Maximum n-gram length used for FastText's n-gram computation. Defaults to the length of `orth`. ~~int~~ | | **RETURNS** | A word vector. Size and shape are determined by the `Vocab.vectors` instance. ~~numpy.ndarray[ndim=1, dtype=float32]~~ | ## Vocab.set_vector {#set_vector tag="method" new="2"} From 610001e8c724ee57fec301469454d80e955385a8 Mon Sep 17 00:00:00 2001 From: vincent d warmerdam Date: Tue, 15 Mar 2022 11:12:04 +0100 Subject: [PATCH 019/424] Update universe.json (#10490) The project moved away from Rasa and into my personal GitHub account. --- website/meta/universe.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/website/meta/universe.json b/website/meta/universe.json index 0179830d0..e178eab1f 100644 --- a/website/meta/universe.json +++ b/website/meta/universe.json @@ -377,10 +377,10 @@ "title": "whatlies", "slogan": "Make interactive visualisations to figure out 'what lies' in word embeddings.", "description": "This small library offers tools to make visualisation easier of both word embeddings as well as operations on them. It has support for spaCy prebuilt models as a first class citizen but also offers support for sense2vec. 
There's a convenient API to perform linear algebra as well as support for popular transformations like PCA/UMAP/etc.", - "github": "rasahq/whatlies", + "github": "koaning/whatlies", "pip": "whatlies", "thumb": "https://i.imgur.com/rOkOiLv.png", - "image": "https://raw.githubusercontent.com/RasaHQ/whatlies/master/docs/gif-two.gif", + "image": "https://raw.githubusercontent.com/koaning/whatlies/master/docs/gif-two.gif", "code_example": [ "from whatlies import EmbeddingSet", "from whatlies.language import SpacyLanguage", From e8357923ec873e5a66129a0ee84e05d42e9234cb Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Tue, 15 Mar 2022 11:12:50 +0100 Subject: [PATCH 020/424] Various install docs updates (#10487) * Simplify quickstart source install to use only editable pip install * Update pytorch install instructions to more recent versions --- website/docs/usage/embeddings-transformers.md | 12 ++++++------ website/src/widgets/quickstart-install.js | 9 +-------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/website/docs/usage/embeddings-transformers.md b/website/docs/usage/embeddings-transformers.md index 708cdd8bf..70fa95099 100644 --- a/website/docs/usage/embeddings-transformers.md +++ b/website/docs/usage/embeddings-transformers.md @@ -211,23 +211,23 @@ PyTorch as a dependency below, but it may not find the best version for your setup. ```bash -### Example: Install PyTorch 1.7.1 for CUDA 10.1 with pip +### Example: Install PyTorch 1.11.0 for CUDA 11.3 with pip # See: https://pytorch.org/get-started/locally/ -$ pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html +$ pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html ``` Next, install spaCy with the extras for your CUDA version and transformers. The -CUDA extra (e.g., `cuda92`, `cuda102`, `cuda111`) installs the correct version -of [`cupy`](https://docs.cupy.dev/en/stable/install.html#installing-cupy), which +CUDA extra (e.g., `cuda102`, `cuda113`) installs the correct version of +[`cupy`](https://docs.cupy.dev/en/stable/install.html#installing-cupy), which is just like `numpy`, but for GPU. You may also need to set the `CUDA_PATH` environment variable if your CUDA runtime is installed in a non-standard -location. Putting it all together, if you had installed CUDA 10.2 in +location. Putting it all together, if you had installed CUDA 11.3 in `/opt/nvidia/cuda`, you would run: ```bash ### Installation with CUDA $ export CUDA_PATH="/opt/nvidia/cuda" -$ pip install -U %%SPACY_PKG_NAME[cuda102,transformers]%%SPACY_PKG_FLAGS +$ pip install -U %%SPACY_PKG_NAME[cuda113,transformers]%%SPACY_PKG_FLAGS ``` For [`transformers`](https://huggingface.co/transformers/) v4.0.0+ and models diff --git a/website/src/widgets/quickstart-install.js b/website/src/widgets/quickstart-install.js index 1c8ad19da..fbf043c7d 100644 --- a/website/src/widgets/quickstart-install.js +++ b/website/src/widgets/quickstart-install.js @@ -214,16 +214,9 @@ const QuickstartInstall = ({ id, title }) => { {nightly ? ` --branch ${DEFAULT_BRANCH}` : ''} cd spaCy - - export PYTHONPATH=`pwd` - - - set PYTHONPATH=C:\path\to\spaCy - pip install -r requirements.txt - python setup.py build_ext --inplace - pip install {train || hardware == 'gpu' ? `'.[${pipExtras}]'` : '.'} + pip install --no-build-isolation --editable {train || hardware == 'gpu' ? 
`'.[${pipExtras}]'` : '.'} # packages only available via pip From e5debc68e4910384351938f574ede7c9b35a2a5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 15 Mar 2022 14:15:31 +0100 Subject: [PATCH 021/424] Tagger: use unnormalized probabilities for inference (#10197) * Tagger: use unnormalized probabilities for inference Using unnormalized softmax avoids use of the relatively expensive exp function, which can significantly speed up non-transformer models (e.g. I got a speedup of 27% on a German tagging + parsing pipeline). * Add spacy.Tagger.v2 with configurable normalization Normalization of probabilities is disabled by default to improve performance. * Update documentation, models, and tests to spacy.Tagger.v2 * Move Tagger.v1 to spacy-legacy * docs/architectures: run prettier * Unnormalized softmax is now a Softmax_v2 option * Require thinc 8.0.14 and spacy-legacy 3.0.9 --- pyproject.toml | 2 +- requirements.txt | 2 +- setup.cfg | 4 ++-- spacy/cli/templates/quickstart_training.jinja | 8 +++---- spacy/ml/models/tagger.py | 10 +++++---- spacy/pipeline/morphologizer.pyx | 2 +- spacy/pipeline/senter.pyx | 2 +- spacy/pipeline/tagger.pyx | 2 +- spacy/tests/pipeline/test_tok2vec.py | 6 +++--- .../tests/serialize/test_serialize_config.py | 4 ++-- .../serialize/test_serialize_language.py | 2 +- spacy/tests/training/test_pretraining.py | 6 +++--- spacy/tests/training/test_training.py | 2 +- website/docs/api/architectures.md | 21 ++++++++++++++----- 14 files changed, 43 insertions(+), 30 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f81484d43..a43b4c814 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = [ "cymem>=2.0.2,<2.1.0", "preshed>=3.0.2,<3.1.0", "murmurhash>=0.28.0,<1.1.0", - "thinc>=8.0.12,<8.1.0", + "thinc>=8.0.14,<8.1.0", "blis>=0.4.0,<0.8.0", "pathy", "numpy>=1.15.0", diff --git a/requirements.txt b/requirements.txt index a034dec27..4da6d5df6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ spacy-legacy>=3.0.9,<3.1.0 spacy-loggers>=1.0.0,<2.0.0 cymem>=2.0.2,<2.1.0 preshed>=3.0.2,<3.1.0 -thinc>=8.0.12,<8.1.0 +thinc>=8.0.14,<8.1.0 blis>=0.4.0,<0.8.0 ml_datasets>=0.2.0,<0.3.0 murmurhash>=0.28.0,<1.1.0 diff --git a/setup.cfg b/setup.cfg index ed3bf63ce..3c5ba884a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,7 +38,7 @@ setup_requires = cymem>=2.0.2,<2.1.0 preshed>=3.0.2,<3.1.0 murmurhash>=0.28.0,<1.1.0 - thinc>=8.0.12,<8.1.0 + thinc>=8.0.14,<8.1.0 install_requires = # Our libraries spacy-legacy>=3.0.9,<3.1.0 @@ -46,7 +46,7 @@ install_requires = murmurhash>=0.28.0,<1.1.0 cymem>=2.0.2,<2.1.0 preshed>=3.0.2,<3.1.0 - thinc>=8.0.12,<8.1.0 + thinc>=8.0.14,<8.1.0 blis>=0.4.0,<0.8.0 wasabi>=0.8.1,<1.1.0 srsly>=2.4.1,<3.0.0 diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja index da533b767..b84fb3a8f 100644 --- a/spacy/cli/templates/quickstart_training.jinja +++ b/spacy/cli/templates/quickstart_training.jinja @@ -54,7 +54,7 @@ stride = 96 factory = "morphologizer" [components.morphologizer.model] -@architectures = "spacy.Tagger.v1" +@architectures = "spacy.Tagger.v2" nO = null [components.morphologizer.model.tok2vec] @@ -70,7 +70,7 @@ grad_factor = 1.0 factory = "tagger" [components.tagger.model] -@architectures = "spacy.Tagger.v1" +@architectures = "spacy.Tagger.v2" nO = null [components.tagger.model.tok2vec] @@ -238,7 +238,7 @@ maxout_pieces = 3 factory = "morphologizer" [components.morphologizer.model] -@architectures = "spacy.Tagger.v1" 
+@architectures = "spacy.Tagger.v2" nO = null [components.morphologizer.model.tok2vec] @@ -251,7 +251,7 @@ width = ${components.tok2vec.model.encode.width} factory = "tagger" [components.tagger.model] -@architectures = "spacy.Tagger.v1" +@architectures = "spacy.Tagger.v2" nO = null [components.tagger.model.tok2vec] diff --git a/spacy/ml/models/tagger.py b/spacy/ml/models/tagger.py index 9c7fe042d..9f8ef7b2b 100644 --- a/spacy/ml/models/tagger.py +++ b/spacy/ml/models/tagger.py @@ -1,14 +1,14 @@ from typing import Optional, List -from thinc.api import zero_init, with_array, Softmax, chain, Model +from thinc.api import zero_init, with_array, Softmax_v2, chain, Model from thinc.types import Floats2d from ...util import registry from ...tokens import Doc -@registry.architectures("spacy.Tagger.v1") +@registry.architectures("spacy.Tagger.v2") def build_tagger_model( - tok2vec: Model[List[Doc], List[Floats2d]], nO: Optional[int] = None + tok2vec: Model[List[Doc], List[Floats2d]], nO: Optional[int] = None, normalize=False ) -> Model[List[Doc], List[Floats2d]]: """Build a tagger model, using a provided token-to-vector component. The tagger model simply adds a linear layer with softmax activation to predict scores @@ -19,7 +19,9 @@ def build_tagger_model( """ # TODO: glorot_uniform_init seems to work a bit better than zero_init here?! t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None - output_layer = Softmax(nO, t2v_width, init_W=zero_init) + output_layer = Softmax_v2( + nO, t2v_width, init_W=zero_init, normalize_outputs=normalize + ) softmax = with_array(output_layer) # type: ignore model = chain(tok2vec, softmax) model.set_ref("tok2vec", tok2vec) diff --git a/spacy/pipeline/morphologizer.pyx b/spacy/pipeline/morphologizer.pyx index 73d3799b1..24f98508f 100644 --- a/spacy/pipeline/morphologizer.pyx +++ b/spacy/pipeline/morphologizer.pyx @@ -25,7 +25,7 @@ BACKWARD_EXTEND = False default_model_config = """ [model] -@architectures = "spacy.Tagger.v1" +@architectures = "spacy.Tagger.v2" [model.tok2vec] @architectures = "spacy.Tok2Vec.v2" diff --git a/spacy/pipeline/senter.pyx b/spacy/pipeline/senter.pyx index 6d00e829d..6808fe70e 100644 --- a/spacy/pipeline/senter.pyx +++ b/spacy/pipeline/senter.pyx @@ -20,7 +20,7 @@ BACKWARD_OVERWRITE = False default_model_config = """ [model] -@architectures = "spacy.Tagger.v1" +@architectures = "spacy.Tagger.v2" [model.tok2vec] @architectures = "spacy.HashEmbedCNN.v2" diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx index e21a9096e..d6ecbf084 100644 --- a/spacy/pipeline/tagger.pyx +++ b/spacy/pipeline/tagger.pyx @@ -27,7 +27,7 @@ BACKWARD_OVERWRITE = False default_model_config = """ [model] -@architectures = "spacy.Tagger.v1" +@architectures = "spacy.Tagger.v2" [model.tok2vec] @architectures = "spacy.HashEmbedCNN.v2" diff --git a/spacy/tests/pipeline/test_tok2vec.py b/spacy/tests/pipeline/test_tok2vec.py index a5ac85e1e..37104c78a 100644 --- a/spacy/tests/pipeline/test_tok2vec.py +++ b/spacy/tests/pipeline/test_tok2vec.py @@ -100,7 +100,7 @@ cfg_string = """ factory = "tagger" [components.tagger.model] - @architectures = "spacy.Tagger.v1" + @architectures = "spacy.Tagger.v2" nO = null [components.tagger.model.tok2vec] @@ -263,7 +263,7 @@ cfg_string_multi = """ factory = "tagger" [components.tagger.model] - @architectures = "spacy.Tagger.v1" + @architectures = "spacy.Tagger.v2" nO = null [components.tagger.model.tok2vec] @@ -373,7 +373,7 @@ cfg_string_multi_textcat = """ factory = "tagger" [components.tagger.model] - 
@architectures = "spacy.Tagger.v1" + @architectures = "spacy.Tagger.v2" nO = null [components.tagger.model.tok2vec] diff --git a/spacy/tests/serialize/test_serialize_config.py b/spacy/tests/serialize/test_serialize_config.py index 1d50fd1d1..85e6f8b2c 100644 --- a/spacy/tests/serialize/test_serialize_config.py +++ b/spacy/tests/serialize/test_serialize_config.py @@ -59,7 +59,7 @@ subword_features = true factory = "tagger" [components.tagger.model] -@architectures = "spacy.Tagger.v1" +@architectures = "spacy.Tagger.v2" [components.tagger.model.tok2vec] @architectures = "spacy.Tok2VecListener.v1" @@ -110,7 +110,7 @@ subword_features = true factory = "tagger" [components.tagger.model] -@architectures = "spacy.Tagger.v1" +@architectures = "spacy.Tagger.v2" [components.tagger.model.tok2vec] @architectures = "spacy.Tok2VecListener.v1" diff --git a/spacy/tests/serialize/test_serialize_language.py b/spacy/tests/serialize/test_serialize_language.py index 6e7fa0e4e..c03287548 100644 --- a/spacy/tests/serialize/test_serialize_language.py +++ b/spacy/tests/serialize/test_serialize_language.py @@ -70,7 +70,7 @@ factory = "ner" factory = "tagger" [components.tagger.model] -@architectures = "spacy.Tagger.v1" +@architectures = "spacy.Tagger.v2" nO = null [components.tagger.model.tok2vec] diff --git a/spacy/tests/training/test_pretraining.py b/spacy/tests/training/test_pretraining.py index 8ee54b544..9359c8485 100644 --- a/spacy/tests/training/test_pretraining.py +++ b/spacy/tests/training/test_pretraining.py @@ -38,7 +38,7 @@ subword_features = true factory = "tagger" [components.tagger.model] -@architectures = "spacy.Tagger.v1" +@architectures = "spacy.Tagger.v2" [components.tagger.model.tok2vec] @architectures = "spacy.Tok2VecListener.v1" @@ -62,7 +62,7 @@ pipeline = ["tagger"] factory = "tagger" [components.tagger.model] -@architectures = "spacy.Tagger.v1" +@architectures = "spacy.Tagger.v2" [components.tagger.model.tok2vec] @architectures = "spacy.HashEmbedCNN.v1" @@ -106,7 +106,7 @@ subword_features = true factory = "tagger" [components.tagger.model] -@architectures = "spacy.Tagger.v1" +@architectures = "spacy.Tagger.v2" [components.tagger.model.tok2vec] @architectures = "spacy.Tok2VecListener.v1" diff --git a/spacy/tests/training/test_training.py b/spacy/tests/training/test_training.py index 0d73300d8..f1f8ce9d4 100644 --- a/spacy/tests/training/test_training.py +++ b/spacy/tests/training/test_training.py @@ -241,7 +241,7 @@ maxout_pieces = 3 factory = "tagger" [components.tagger.model] -@architectures = "spacy.Tagger.v1" +@architectures = "spacy.Tagger.v2" nO = null [components.tagger.model.tok2vec] diff --git a/website/docs/api/architectures.md b/website/docs/api/architectures.md index 5fb3546a7..2bddcb28c 100644 --- a/website/docs/api/architectures.md +++ b/website/docs/api/architectures.md @@ -104,7 +104,7 @@ consisting of a CNN and a layer-normalized maxout activation function. > factory = "tagger" > > [components.tagger.model] -> @architectures = "spacy.Tagger.v1" +> @architectures = "spacy.Tagger.v2" > > [components.tagger.model.tok2vec] > @architectures = "spacy.Tok2VecListener.v1" @@ -158,8 +158,8 @@ be configured with the `attrs` argument. The suggested attributes are `NORM`, `PREFIX`, `SUFFIX` and `SHAPE`. This lets the model take into account some subword information, without construction a fully character-based representation. If pretrained vectors are available, they can be included in the -representation as well, with the vectors table kept static (i.e. it's -not updated). 
+representation as well, with the vectors table kept static (i.e. it's not +updated). | Name | Description | | ------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | @@ -613,14 +613,15 @@ same signature, but the `use_upper` argument was `True` by default. ## Tagging architectures {#tagger source="spacy/ml/models/tagger.py"} -### spacy.Tagger.v1 {#Tagger} +### spacy.Tagger.v2 {#Tagger} > #### Example Config > > ```ini > [model] -> @architectures = "spacy.Tagger.v1" +> @architectures = "spacy.Tagger.v2" > nO = null +> normalize = false > > [model.tok2vec] > # ... @@ -634,8 +635,18 @@ the token vectors. | ----------- | ------------------------------------------------------------------------------------------ | | `tok2vec` | Subnetwork to map tokens into vector representations. ~~Model[List[Doc], List[Floats2d]]~~ | | `nO` | The number of tags to output. Inferred from the data if `None`. ~~Optional[int]~~ | +| `normalize` | Normalize probabilities during inference. Defaults to `False`. ~~bool~~ | | **CREATES** | The model using the architecture. ~~Model[List[Doc], List[Floats2d]]~~ | + + +- The `normalize` argument was added in `spacy.Tagger.v2`. `spacy.Tagger.v1` + always normalizes probabilities during inference. + +The other arguments are shared between all versions. + + + ## Text classification architectures {#textcat source="spacy/ml/models/textcat.py"} A text classification architecture needs to take a [`Doc`](/api/doc) as input, From e021dc6279621ccdb00bd69961d12a19e47218a1 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Tue, 15 Mar 2022 16:42:33 +0100 Subject: [PATCH 022/424] Updated explanation for classy classification (#10484) * Update universe.json added classy-classification to the spaCy universe * Update universe.json added classy-classification to the spaCy universe resources * Update universe.json corrected a small typo in json * Update website/meta/universe.json Co-authored-by: Sofie Van Landeghem * Update website/meta/universe.json Co-authored-by: Sofie Van Landeghem * Update website/meta/universe.json Co-authored-by: Sofie Van Landeghem * Update universe.json processed merge feedback * Update universe.json * updated information for Classy Classification Made the Classy Classification description more comprehensible and accessible, based on feedback from Philip Vollet, to prepare for sharing.
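The `normalize` flag documented in the previous patch is easiest to understand from user code. A minimal sketch, not part of the patch itself, assuming spaCy v3's dict-based config overrides for `add_pipe`; the `tok2vec` hyperparameters are assumed stock defaults, not values taken from this diff:

```python
# Sketch: build a tagger with the v2 architecture and the new normalize flag.
# The tok2vec block below is an assumed default configuration.
import spacy

config = {
    "model": {
        "@architectures": "spacy.Tagger.v2",
        "nO": None,
        "normalize": True,  # v1 always normalized; v2 makes this opt-in
        "tok2vec": {
            "@architectures": "spacy.HashEmbedCNN.v2",
            "pretrained_vectors": None,
            "width": 96,
            "depth": 4,
            "embed_size": 2000,
            "window_size": 1,
            "maxout_pieces": 3,
            "subword_features": True,
        },
    }
}

nlp = spacy.blank("en")
tagger = nlp.add_pipe("tagger", config=config)
```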
* added note about examples * corrected wrong formatting changes * Update website/meta/universe.json with small typo correction Co-authored-by: Adriane Boyd * resolved another typo * Update website/meta/universe.json Co-authored-by: Sofie Van Landeghem Co-authored-by: Sofie Van Landeghem Co-authored-by: Adriane Boyd --- website/meta/universe.json | 43 +++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/website/meta/universe.json b/website/meta/universe.json index e178eab1f..a930363a4 100644 --- a/website/meta/universe.json +++ b/website/meta/universe.json @@ -2601,8 +2601,9 @@ }, { "id": "classyclassification", - "slogan": "A Python library for classy few-shot and zero-shot classification within spaCy.", - "description": "Huggingface does offer some nice models for few/zero-shot classification, but these are not tailored to multi-lingual approaches. Rasa NLU has a nice approach for this, but its too embedded in their codebase for easy usage outside of Rasa/chatbots. Additionally, it made sense to integrate sentence-transformers and Huggingface zero-shot, instead of default word embeddings. Finally, I decided to integrate with spaCy, since training a custom spaCy TextCategorizer seems like a lot of hassle if you want something quick and dirty.", + "title": "Classy Classification", + "slogan": "Have you ever struggled with needing a spaCy TextCategorizer but didn't have the time to train one from scratch? Classy Classification is the way to go!", + "description": "Have you ever struggled with needing a [spaCy TextCategorizer](https://spacy.io/api/textcategorizer) but didn't have the time to train one from scratch? Classy Classification is the way to go! For few-shot classification using [sentence-transformers](https://github.com/UKPLab/sentence-transformers) or [spaCy models](https://spacy.io/usage/models), provide a dictionary with labels and examples, or just provide a list of labels for zero-shot classification with [Huggingface zero-shot classifiers](https://huggingface.co/models?pipeline_tag=zero-shot-classification).", "github": "davidberenstein1957/classy-classification", "pip": "classy-classification", "code_example": [ @@ -2618,32 +2619,36 @@ " \"Do you also have some ovens.\"]", "}", "", + "# see github repo for examples on sentence-transformers and Huggingface", "nlp = spacy.load('en_core_web_md')", "", - "classification_type = \"spacy_few_shot\"", - "if classification_type == \"spacy_few_shot\":", - " nlp.add_pipe(\"text_categorizer\", ", - " config={\"data\": data, \"model\": \"spacy\"}", - " )", - "elif classification_type == \"sentence_transformer_few_shot\":", - " nlp.add_pipe(\"text_categorizer\", ", - " config={\"data\": data, \"model\": \"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2\"}", - " )", - "elif classification_type == \"huggingface_zero_shot\":", - " nlp.add_pipe(\"text_categorizer\", ", - " config={\"data\": list(data.keys()), \"cat_type\": \"zero\", \"model\": \"facebook/bart-large-mnli\"}", - " )", + "nlp.add_pipe(\"text_categorizer\", ", + " config={", + " \"data\": data,", + " \"model\": \"spacy\"", + " }", + ")", "", "print(nlp(\"I am looking for kitchen appliances.\")._.cats)", - "print([doc._.cats for doc in nlp.pipe([\"I am looking for kitchen appliances.\"])])" + "# Output:", + "#", + "# [{\"label\": \"furniture\", \"score\": 0.21}, {\"label\": \"kitchen\", \"score\": 0.79}]" ], "author": "David Berenstein", "author_links": { "github": "davidberenstein1957", "website":
"https://www.linkedin.com/in/david-berenstein-1bab11105/" }, - "category": ["pipeline", "standalone"], - "tags": ["classification", "zero-shot", "few-shot", "sentence-transformers", "huggingface"], + "category": [ + "pipeline", + "standalone" + ], + "tags": [ + "classification", + "zero-shot", + "few-shot", + "sentence-transformers", + "huggingface" + ], "spacy_version": 3 }, { From a79cd3542b3dd667d8a97293462e22ed26a04ee5 Mon Sep 17 00:00:00 2001 From: Lj Miranda <12949683+ljvmiranda921@users.noreply.github.com> Date: Thu, 17 Mar 2022 01:14:34 +0800 Subject: [PATCH 023/424] Add displacy support for overlapping Spans (#10332) * Fix docstring for EntityRenderer * Add warning in displacy if doc.spans are empty * Implement parse_spans converter One notable change here is that the default spans_key is sc, and it's set by the user through the options. * Implement SpanRenderer Here, I implemented a SpanRenderer that looks similar to the EntityRenderer except for some templates. The spans_key, by default, is set to sc, but can be configured in the options (see parse_spans). The way I rendered these spans is per-token, i.e., I first check if each token (1) belongs to a given span type and (2) a starting token of a given span type. Once I have this information, I render them into the markup. * Fix mypy issues on typing * Add tests for displacy spans support * Update colors from RGB to hex Co-authored-by: Ines Montani * Remove unnecessary CSS properties * Add documentation for website * Remove unnecesasry scripts * Update wording on the documentation Co-authored-by: Sofie Van Landeghem * Put typing dependency on top of file * Put back z-index so that spans overlap properly * Make warning more explicit for spans_key Co-authored-by: Ines Montani Co-authored-by: Sofie Van Landeghem --- spacy/displacy/__init__.py | 41 +++- spacy/displacy/render.py | 179 +++++++++++++++++- spacy/displacy/templates.py | 49 +++++ spacy/errors.py | 4 + spacy/tests/test_displacy.py | 86 +++++++++ website/docs/api/top-level.md | 32 +++- website/docs/images/displacy-span-custom.html | 31 +++ website/docs/images/displacy-span.html | 41 ++++ website/docs/usage/visualizers.md | 53 ++++++ 9 files changed, 501 insertions(+), 15 deletions(-) create mode 100644 website/docs/images/displacy-span-custom.html create mode 100644 website/docs/images/displacy-span.html diff --git a/spacy/displacy/__init__.py b/spacy/displacy/__init__.py index 25d530c83..aa00c95d8 100644 --- a/spacy/displacy/__init__.py +++ b/spacy/displacy/__init__.py @@ -4,10 +4,10 @@ spaCy's built in visualization suite for dependencies and named entities. 
DOCS: https://spacy.io/api/top-level#displacy USAGE: https://spacy.io/usage/visualizers """ -from typing import Union, Iterable, Optional, Dict, Any, Callable +from typing import List, Union, Iterable, Optional, Dict, Any, Callable import warnings -from .render import DependencyRenderer, EntityRenderer +from .render import DependencyRenderer, EntityRenderer, SpanRenderer from ..tokens import Doc, Span from ..errors import Errors, Warnings from ..util import is_in_jupyter @@ -44,6 +44,7 @@ def render( factories = { "dep": (DependencyRenderer, parse_deps), "ent": (EntityRenderer, parse_ents), + "span": (SpanRenderer, parse_spans), } if style not in factories: raise ValueError(Errors.E087.format(style=style)) @@ -203,6 +204,42 @@ def parse_ents(doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]: return {"text": doc.text, "ents": ents, "title": title, "settings": settings} +def parse_spans(doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]: + """Generate spans in [{start: i, end: i, label: 'label'}] format. + + doc (Doc): Document to parse. + options (Dict[str, any]): Span-specific visualisation options. + RETURNS (dict): Generated span types keyed by text (original text) and spans. + """ + kb_url_template = options.get("kb_url_template", None) + spans_key = options.get("spans_key", "sc") + spans = [ + { + "start": span.start_char, + "end": span.end_char, + "start_token": span.start, + "end_token": span.end, + "label": span.label_, + "kb_id": span.kb_id_ if span.kb_id_ else "", + "kb_url": kb_url_template.format(span.kb_id_) if kb_url_template else "#", + } + for span in doc.spans[spans_key] + ] + tokens = [token.text for token in doc] + + if not spans: + warnings.warn(Warnings.W117.format(spans_key=spans_key)) + title = doc.user_data.get("title", None) if hasattr(doc, "user_data") else None + settings = get_doc_settings(doc) + return { + "text": doc.text, + "spans": spans, + "title": title, + "settings": settings, + "tokens": tokens, + } + + def set_render_wrapper(func: Callable[[str], str]) -> None: """Set an optional wrapper function that is called around the generated HTML markup on displacy.render. 
This can be used to allow integration into diff --git a/spacy/displacy/render.py b/spacy/displacy/render.py index a032d843b..2925c68a0 100644 --- a/spacy/displacy/render.py +++ b/spacy/displacy/render.py @@ -1,12 +1,15 @@ -from typing import Dict, Any, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import uuid +import itertools -from .templates import TPL_DEP_SVG, TPL_DEP_WORDS, TPL_DEP_WORDS_LEMMA, TPL_DEP_ARCS -from .templates import TPL_ENT, TPL_ENT_RTL, TPL_FIGURE, TPL_TITLE, TPL_PAGE -from .templates import TPL_ENTS, TPL_KB_LINK -from ..util import minify_html, escape_html, registry from ..errors import Errors - +from ..util import escape_html, minify_html, registry +from .templates import TPL_DEP_ARCS, TPL_DEP_SVG, TPL_DEP_WORDS +from .templates import TPL_DEP_WORDS_LEMMA, TPL_ENT, TPL_ENT_RTL, TPL_ENTS +from .templates import TPL_FIGURE, TPL_KB_LINK, TPL_PAGE, TPL_SPAN +from .templates import TPL_SPAN_RTL, TPL_SPAN_SLICE, TPL_SPAN_SLICE_RTL +from .templates import TPL_SPAN_START, TPL_SPAN_START_RTL, TPL_SPANS +from .templates import TPL_TITLE DEFAULT_LANG = "en" DEFAULT_DIR = "ltr" @@ -33,6 +36,168 @@ DEFAULT_LABEL_COLORS = { } +class SpanRenderer: + """Render Spans as HTML.""" + + style = "span" + + def __init__(self, options: Dict[str, Any] = {}) -> None: + """Initialise span renderer. + + options (dict): Visualiser-specific options (colors, spans) + """ + # Set up the colors and overall look + colors = dict(DEFAULT_LABEL_COLORS) + user_colors = registry.displacy_colors.get_all() + for user_color in user_colors.values(): + if callable(user_color): + # Since this comes from the function registry, we want to make + # sure we support functions that *return* a dict of colors + user_color = user_color() + if not isinstance(user_color, dict): + raise ValueError(Errors.E925.format(obj=type(user_color))) + colors.update(user_color) + colors.update(options.get("colors", {})) + self.default_color = DEFAULT_ENTITY_COLOR + self.colors = {label.upper(): color for label, color in colors.items()} + + # Set up how the text and labels will be rendered + self.direction = DEFAULT_DIR + self.lang = DEFAULT_LANG + self.top_offset = options.get("top_offset", 40) + self.top_offset_step = options.get("top_offset_step", 17) + + # Set up which templates will be used + template = options.get("template") + if template: + self.span_template = template["span"] + self.span_slice_template = template["slice"] + self.span_start_template = template["start"] + else: + if self.direction == "rtl": + self.span_template = TPL_SPAN_RTL + self.span_slice_template = TPL_SPAN_SLICE_RTL + self.span_start_template = TPL_SPAN_START_RTL + else: + self.span_template = TPL_SPAN + self.span_slice_template = TPL_SPAN_SLICE + self.span_start_template = TPL_SPAN_START + + def render( + self, parsed: List[Dict[str, Any]], page: bool = False, minify: bool = False + ) -> str: + """Render complete markup. + + parsed (list): Span parses to render. + page (bool): Render parses wrapped as full HTML page. + minify (bool): Minify HTML markup. + RETURNS (str): Rendered HTML markup.
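The options consumed by `parse_spans` earlier in this patch thread straight through to this renderer. A hedged sketch of a caller using a non-default spans key together with KB links; the Wikidata URL template mirrors the new tests rather than anything mandated by the API:

```python
# Sketch: custom spans_key plus KB links via the options read by parse_spans.
import spacy
from spacy import displacy
from spacy.tokens import Span

nlp = spacy.blank("en")
doc = nlp("Welcome to the Bank of China.")
doc.spans["custom"] = [Span(doc, 3, 6, "ORG", kb_id="Q790068")]

html = displacy.render(
    doc,
    style="span",
    options={
        "spans_key": "custom",
        "kb_url_template": "https://wikidata.org/wiki/{}",
    },
)
```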
+ """ + rendered = [] + for i, p in enumerate(parsed): + if i == 0: + settings = p.get("settings", {}) + self.direction = settings.get("direction", DEFAULT_DIR) + self.lang = settings.get("lang", DEFAULT_LANG) + rendered.append(self.render_spans(p["tokens"], p["spans"], p.get("title"))) + + if page: + docs = "".join([TPL_FIGURE.format(content=doc) for doc in rendered]) + markup = TPL_PAGE.format(content=docs, lang=self.lang, dir=self.direction) + else: + markup = "".join(rendered) + if minify: + return minify_html(markup) + return markup + + def render_spans( + self, + tokens: List[str], + spans: List[Dict[str, Any]], + title: Optional[str], + ) -> str: + """Render span types in text. + + Spans are rendered per-token, this means that for each token, we check if it's part + of a span slice (a member of a span type) or a span start (the starting token of a + given span type). + + tokens (list): Individual tokens in the text + spans (list): Individual entity spans and their start, end, label, kb_id and kb_url. + title (str / None): Document title set in Doc.user_data['title']. + """ + per_token_info = [] + for idx, token in enumerate(tokens): + # Identify if a token belongs to a Span (and which) and if it's a + # start token of said Span. We'll use this for the final HTML render + token_markup: Dict[str, Any] = {} + token_markup["text"] = token + entities = [] + for span in spans: + ent = {} + if span["start_token"] <= idx < span["end_token"]: + ent["label"] = span["label"] + ent["is_start"] = True if idx == span["start_token"] else False + kb_id = span.get("kb_id", "") + kb_url = span.get("kb_url", "#") + ent["kb_link"] = ( + TPL_KB_LINK.format(kb_id=kb_id, kb_url=kb_url) if kb_id else "" + ) + entities.append(ent) + token_markup["entities"] = entities + per_token_info.append(token_markup) + + markup = self._render_markup(per_token_info) + markup = TPL_SPANS.format(content=markup, dir=self.direction) + if title: + markup = TPL_TITLE.format(title=title) + markup + return markup + + def _render_markup(self, per_token_info: List[Dict[str, Any]]) -> str: + """Render the markup from per-token information""" + markup = "" + for token in per_token_info: + entities = sorted(token["entities"], key=lambda d: d["label"]) + if entities: + slices = self._get_span_slices(token["entities"]) + starts = self._get_span_starts(token["entities"]) + markup += self.span_template.format( + text=token["text"], span_slices=slices, span_starts=starts + ) + else: + markup += escape_html(token["text"] + " ") + return markup + + def _get_span_slices(self, entities: List[Dict]) -> str: + """Get the rendered markup of all Span slices""" + span_slices = [] + for entity, step in zip(entities, itertools.count(step=self.top_offset_step)): + color = self.colors.get(entity["label"].upper(), self.default_color) + span_slice = self.span_slice_template.format( + bg=color, top_offset=self.top_offset + step + ) + span_slices.append(span_slice) + return "".join(span_slices) + + def _get_span_starts(self, entities: List[Dict]) -> str: + """Get the rendered markup of all Span start tokens""" + span_starts = [] + for entity, step in zip(entities, itertools.count(step=self.top_offset_step)): + color = self.colors.get(entity["label"].upper(), self.default_color) + span_start = ( + self.span_start_template.format( + bg=color, + top_offset=self.top_offset + step, + label=entity["label"], + kb_link=entity["kb_link"], + ) + if entity["is_start"] + else "" + ) + span_starts.append(span_start) + return "".join(span_starts) + + class 
DependencyRenderer: """Render dependency parses as SVGs.""" @@ -242,7 +407,7 @@ class EntityRenderer: style = "ent" def __init__(self, options: Dict[str, Any] = {}) -> None: - """Initialise dependency renderer. + """Initialise entity renderer. options (dict): Visualiser-specific options (colors, ents) """ diff --git a/spacy/displacy/templates.py b/spacy/displacy/templates.py index e7d3d4266..ff81e7a1d 100644 --- a/spacy/displacy/templates.py +++ b/spacy/displacy/templates.py @@ -62,6 +62,55 @@ TPL_ENT_RTL = """ """ +TPL_SPANS = """ +
{content}
+""" + +TPL_SPAN = """ + + {text} + {span_slices} + {span_starts} + +""" + +TPL_SPAN_SLICE = """ + + +""" + + +TPL_SPAN_START = """ + + + {label}{kb_link} + + + +""" + +TPL_SPAN_RTL = """ + + {text} + {span_slices} + {span_starts} + +""" + +TPL_SPAN_SLICE_RTL = """ + + +""" + +TPL_SPAN_START_RTL = """ + + + {label}{kb_link} + + +""" + + # Important: this needs to start with a space! TPL_KB_LINK = """ {kb_id} diff --git a/spacy/errors.py b/spacy/errors.py index 5399e489b..fe37351f7 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -192,6 +192,10 @@ class Warnings(metaclass=ErrorsWithCodes): W115 = ("Skipping {method}: the floret vector table cannot be modified. " "Vectors are calculated from character ngrams.") W116 = ("Unable to clean attribute '{attr}'.") + W117 = ("No spans to visualize found in Doc object with spans_key: '{spans_key}'. If this is " + "surprising to you, make sure the Doc was processed using a model " + "that supports span categorization, and check the `doc.spans[spans_key]` " + "property manually if necessary.") class Errors(metaclass=ErrorsWithCodes): diff --git a/spacy/tests/test_displacy.py b/spacy/tests/test_displacy.py index 392c95e42..ccad7e342 100644 --- a/spacy/tests/test_displacy.py +++ b/spacy/tests/test_displacy.py @@ -96,6 +96,92 @@ def test_issue5838(): assert found == 4 +def test_displacy_parse_spans(en_vocab): + """Test that spans on a Doc are converted into displaCy's format.""" + doc = Doc(en_vocab, words=["Welcome", "to", "the", "Bank", "of", "China"]) + doc.spans["sc"] = [Span(doc, 3, 6, "ORG"), Span(doc, 5, 6, "GPE")] + spans = displacy.parse_spans(doc) + assert isinstance(spans, dict) + assert spans["text"] == "Welcome to the Bank of China " + assert spans["spans"] == [ + { + "start": 15, + "end": 28, + "start_token": 3, + "end_token": 6, + "label": "ORG", + "kb_id": "", + "kb_url": "#", + }, + { + "start": 23, + "end": 28, + "start_token": 5, + "end_token": 6, + "label": "GPE", + "kb_id": "", + "kb_url": "#", + }, + ] + + +def test_displacy_parse_spans_with_kb_id_options(en_vocab): + """Test that spans with kb_id on a Doc are converted into displaCy's format""" + doc = Doc(en_vocab, words=["Welcome", "to", "the", "Bank", "of", "China"]) + doc.spans["sc"] = [ + Span(doc, 3, 6, "ORG", kb_id="Q790068"), + Span(doc, 5, 6, "GPE", kb_id="Q148"), + ] + + spans = displacy.parse_spans( + doc, {"kb_url_template": "https://wikidata.org/wiki/{}"} + ) + assert isinstance(spans, dict) + assert spans["text"] == "Welcome to the Bank of China " + assert spans["spans"] == [ + { + "start": 15, + "end": 28, + "start_token": 3, + "end_token": 6, + "label": "ORG", + "kb_id": "Q790068", + "kb_url": "https://wikidata.org/wiki/Q790068", + }, + { + "start": 23, + "end": 28, + "start_token": 5, + "end_token": 6, + "label": "GPE", + "kb_id": "Q148", + "kb_url": "https://wikidata.org/wiki/Q148", + }, + ] + + +def test_displacy_parse_spans_different_spans_key(en_vocab): + """Test that spans in a different spans key will be parsed""" + doc = Doc(en_vocab, words=["Welcome", "to", "the", "Bank", "of", "China"]) + doc.spans["sc"] = [Span(doc, 3, 6, "ORG"), Span(doc, 5, 6, "GPE")] + doc.spans["custom"] = [Span(doc, 3, 6, "BANK")] + spans = displacy.parse_spans(doc, options={"spans_key": "custom"}) + + assert isinstance(spans, dict) + assert spans["text"] == "Welcome to the Bank of China " + assert spans["spans"] == [ + { + "start": 15, + "end": 28, + "start_token": 3, + "end_token": 6, + "label": "BANK", + "kb_id": "", + "kb_url": "#", + } + ] + + def 
test_displacy_parse_ents(en_vocab): """Test that named entities on a Doc are converted into displaCy's format.""" doc = Doc(en_vocab, words=["But", "Google", "is", "starting", "from", "behind"]) diff --git a/website/docs/api/top-level.md b/website/docs/api/top-level.md index 1a3e9da46..6d7431f28 100644 --- a/website/docs/api/top-level.md +++ b/website/docs/api/top-level.md @@ -320,12 +320,31 @@ If a setting is not present in the options, the default value will be used. | `template` 2.2 | Optional template to overwrite the HTML used to render entity spans. Should be a format string and can use `{bg}`, `{text}` and `{label}`. See [`templates.py`](%%GITHUB_SPACY/spacy/displacy/templates.py) for examples. ~~Optional[str]~~ | | `kb_url_template` 3.2.1 | Optional template to construct the KB url for the entity to link to. Expects a python f-string format with a single field to fill in. ~~Optional[str]~~ | -By default, displaCy comes with colors for all entity types used by -[spaCy's trained pipelines](/models). If you're using custom entity types, you -can use the `colors` setting to add your own colors for them. Your application -or pipeline package can also expose a -[`spacy_displacy_colors` entry point](/usage/saving-loading#entry-points-displacy) -to add custom labels and their colors automatically. + +#### Span Visualizer options {#displacy_options-span} + +> #### Example +> +> ```python +> options = {"spans_key": "sc"} +> displacy.serve(doc, style="span", options=options) +> ``` + +| Name | Description | |-----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------| | `spans_key` | Which spans key to render spans from. Default is `"sc"`. ~~str~~ | | `templates` | Dictionary containing the keys `"span"`, `"slice"`, and `"start"`. These dictate how the overall span, a span slice, and the starting token will be rendered. ~~Optional[Dict[str, str]]~~ | | `kb_url_template` | Optional template to construct the KB url for the entity to link to. Expects a python f-string format with a single field to fill in. ~~Optional[str]~~ | | `colors` | Color overrides. Entity types should be mapped to color names or values. ~~Dict[str, str]~~ | + +By default, displaCy comes with colors for all entity types used by [spaCy's +trained pipelines](/models) for both the entity and span visualizers. If you're
using custom entity types, you can use the `colors` setting to add your own
colors for them. Your application or pipeline package can also expose a
[`spacy_displacy_colors` entry
point](/usage/saving-loading#entry-points-displacy) to add custom labels and
their colors automatically. By default, displaCy links to `#` for entities without a `kb_id` set on their span. If you wish to link an entity to its URL then consider using the @@ -335,6 +354,7 @@ span. If you wish to link an entity to its URL then consider using the should redirect you to their Wikidata page, in this case `https://www.wikidata.org/wiki/Q95`. + ## registry {#registry source="spacy/util.py" new="3"} spaCy's function registry extends diff --git a/website/docs/images/displacy-span-custom.html b/website/docs/images/displacy-span-custom.html new file mode 100644 index 000000000..97dd3b140 --- /dev/null +++ b/website/docs/images/displacy-span-custom.html @@ -0,0 +1,31 @@ +
+ Welcome to the + + Bank + + + + + BANK + + + + + of + + + + + China + + + + + . +
\ No newline at end of file diff --git a/website/docs/images/displacy-span.html b/website/docs/images/displacy-span.html new file mode 100644 index 000000000..9bbc6403c --- /dev/null +++ b/website/docs/images/displacy-span.html @@ -0,0 +1,41 @@ +
+ Welcome to the + + Bank + + + + + ORG + + + + + of + + + + + + China + + + + + + + GPE + + + + . +
\ No newline at end of file diff --git a/website/docs/usage/visualizers.md b/website/docs/usage/visualizers.md index 072718f91..f98c43224 100644 --- a/website/docs/usage/visualizers.md +++ b/website/docs/usage/visualizers.md @@ -167,6 +167,59 @@ This feature is especially handy if you're using displaCy to compare performance at different stages of a process, e.g. during training. Here you could use the title for a brief description of the text example and the number of iterations. +## Visualizing spans {#span} + +The span visualizer, `span`, highlights overlapping spans in a text. + +```python +### Span example +import spacy +from spacy import displacy +from spacy.tokens import Span + +text = "Welcome to the Bank of China." + +nlp = spacy.blank("en") +doc = nlp(text) + +doc.spans["sc"] = [ + Span(doc, 3, 6, "ORG"), + Span(doc, 5, 6, "GPE"), +] + +displacy.serve(doc, style="span") +``` + +import DisplacySpanHtml from 'images/displacy-span.html' + +
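The options documented earlier in this patch combine naturally with the usage example above. A short sketch using a non-default `spans_key` together with a `colors` override; the label, key, and hex value are illustrative choices mirroring the new tests, not fixed names:

```python
# Sketch: span visualizer with a custom spans_key and a color override.
import spacy
from spacy import displacy
from spacy.tokens import Span

nlp = spacy.blank("en")
doc = nlp("Welcome to the Bank of China.")
doc.spans["custom"] = [Span(doc, 3, 6, "BANK")]

options = {"spans_key": "custom", "colors": {"BANK": "#ddd"}}
displacy.serve(doc, style="span", options=options)
```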