diff --git a/spacy/language.py b/spacy/language.py index ba244617e..b438936a6 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -843,7 +843,7 @@ class Language: *, config: Dict[str, Any] = SimpleFrozenDict(), validate: bool = True, - ) -> None: + ) -> Callable[[Doc], Doc]: """Replace a component in the pipeline. name (str): Name of the component to replace. @@ -852,6 +852,7 @@ class Language: component. Will be merged with default config, if available. validate (bool): Whether to validate the component config against the arguments and types expected by the factory. + RETURNS (Callable[[Doc], Doc]): The new pipeline component. DOCS: https://nightly.spacy.io/api/language#replace_pipe """ @@ -866,9 +867,11 @@ class Language: self.remove_pipe(name) if not len(self._components) or pipe_index == len(self._components): # we have no components to insert before/after, or we're replacing the last component - self.add_pipe(factory_name, name=name, config=config, validate=validate) + return self.add_pipe( + factory_name, name=name, config=config, validate=validate + ) else: - self.add_pipe( + return self.add_pipe( factory_name, name=name, before=pipe_index, @@ -1300,7 +1303,11 @@ class Language: kwargs.setdefault("batch_size", batch_size) # non-trainable components may have a pipe() implementation that refers to dummy # predict and set_annotations methods - if not hasattr(pipe, "pipe") or not hasattr(pipe, "is_trainable") or not pipe.is_trainable(): + if ( + not hasattr(pipe, "pipe") + or not hasattr(pipe, "is_trainable") + or not pipe.is_trainable() + ): docs = _pipe(docs, pipe, kwargs) else: docs = pipe.pipe(docs, **kwargs) @@ -1412,7 +1419,11 @@ class Language: kwargs.setdefault("batch_size", batch_size) # non-trainable components may have a pipe() implementation that refers to dummy # predict and set_annotations methods - if hasattr(proc, "pipe") and hasattr(proc, "is_trainable") and proc.is_trainable(): + if ( + hasattr(proc, "pipe") + and hasattr(proc, "is_trainable") + and proc.is_trainable() + ): f = functools.partial(proc.pipe, **kwargs) else: # Apply the function, but yield the doc diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 2a5f3962d..eec591995 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -8,6 +8,7 @@ from thinc.api import set_dropout_rate import warnings from ..kb import KnowledgeBase, Candidate +from ..ml import empty_kb from ..tokens import Doc from .pipe import Pipe, deserialize_config from ..language import Language @@ -41,11 +42,11 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"] requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"], assigns=["token.ent_kb_id"], default_config={ - "kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 64}, "model": DEFAULT_NEL_MODEL, "labels_discard": [], "incl_prior": True, "incl_context": True, + "entity_vector_length": 64, "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, }, default_score_weights={ @@ -58,11 +59,11 @@ def make_entity_linker( nlp: Language, name: str, model: Model, - kb_loader: Callable[[Vocab], KnowledgeBase], *, labels_discard: Iterable[str], incl_prior: bool, incl_context: bool, + entity_vector_length: int, get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]], ): """Construct an EntityLinker component. @@ -70,19 +71,21 @@ def make_entity_linker( model (Model[List[Doc], Floats2d]): A model that learns document vector representations. Given a batch of Doc objects, it should return a single array, with one row per item in the batch. - kb (KnowledgeBase): The knowledge-base to link entities to. labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction. incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. incl_context (bool): Whether or not to include the local context in the model. + entity_vector_length (int): Size of encoding vectors in the KB. + get_candidates (Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]): Function that + produces a list of candidates, given a certain knowledge base and a textual mention. """ return EntityLinker( nlp.vocab, model, name, - kb_loader=kb_loader, labels_discard=labels_discard, incl_prior=incl_prior, incl_context=incl_context, + entity_vector_length=entity_vector_length, get_candidates=get_candidates, ) @@ -101,10 +104,10 @@ class EntityLinker(Pipe): model: Model, name: str = "entity_linker", *, - kb_loader: Callable[[Vocab], KnowledgeBase], labels_discard: Iterable[str], incl_prior: bool, incl_context: bool, + entity_vector_length: int, get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]], ) -> None: """Initialize an entity linker. @@ -113,10 +116,12 @@ class EntityLinker(Pipe): model (thinc.api.Model): The Thinc Model powering the pipeline component. name (str): The component instance name, used to add entries to the losses during training. - kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance. labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction. incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. incl_context (bool): Whether or not to include the local context in the model. + entity_vector_length (int): Size of encoding vectors in the KB. + get_candidates (Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]): Function that + produces a list of candidates, given a certain knowledge base and a textual mention. DOCS: https://nightly.spacy.io/api/entitylinker#init """ @@ -127,15 +132,23 @@ class EntityLinker(Pipe): "labels_discard": list(labels_discard), "incl_prior": incl_prior, "incl_context": incl_context, + "entity_vector_length": entity_vector_length, } - self.kb = kb_loader(self.vocab) self.get_candidates = get_candidates self.cfg = dict(cfg) self.distance = CosineDistance(normalize=False) # how many neightbour sentences to take into account self.n_sents = cfg.get("n_sents", 0) + # create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'. + self.kb = empty_kb(entity_vector_length)(self.vocab) - def _require_kb(self) -> None: + def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]): + """Define the KB of this pipe by providing a function that will + create it using this object's vocab.""" + self.kb = kb_loader(self.vocab) + self.cfg["entity_vector_length"] = self.kb.entity_vector_length + + def validate_kb(self) -> None: # Raise an error if the knowledge base is not initialized. if len(self.kb) == 0: raise ValueError(Errors.E139.format(name=self.name)) @@ -145,6 +158,7 @@ class EntityLinker(Pipe): get_examples: Callable[[], Iterable[Example]], *, nlp: Optional[Language] = None, + kb_loader: Callable[[Vocab], KnowledgeBase] = None, ): """Initialize the pipe for training, using a representative set of data examples. @@ -152,11 +166,16 @@ class EntityLinker(Pipe): get_examples (Callable[[], Iterable[Example]]): Function that returns a representative sample of gold-standard Example objects. nlp (Language): The current nlp object the component is part of. + kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance. + Note that providing this argument, will overwrite all data accumulated in the current KB. + Use this only when loading a KB as-such from file. DOCS: https://nightly.spacy.io/api/entitylinker#initialize """ self._ensure_examples(get_examples) - self._require_kb() + if kb_loader is not None: + self.set_kb(kb_loader) + self.validate_kb() nO = self.kb.entity_vector_length doc_sample = [] vector_sample = [] @@ -192,7 +211,7 @@ class EntityLinker(Pipe): DOCS: https://nightly.spacy.io/api/entitylinker#update """ - self._require_kb() + self.validate_kb() if losses is None: losses = {} losses.setdefault(self.name, 0.0) @@ -303,7 +322,7 @@ class EntityLinker(Pipe): DOCS: https://nightly.spacy.io/api/entitylinker#predict """ - self._require_kb() + self.validate_kb() entity_count = 0 final_kb_ids = [] if not docs: diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 66de54c06..e77be74ad 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -110,7 +110,7 @@ def test_kb_invalid_entity_vector(nlp): def test_kb_default(nlp): - """Test that the default (empty) KB is loaded when not providing a config""" + """Test that the default (empty) KB is loaded upon construction""" entity_linker = nlp.add_pipe("entity_linker", config={}) assert len(entity_linker.kb) == 0 assert entity_linker.kb.get_size_entities() == 0 @@ -122,7 +122,7 @@ def test_kb_default(nlp): def test_kb_custom_length(nlp): """Test that the default (empty) KB can be configured with a custom entity length""" entity_linker = nlp.add_pipe( - "entity_linker", config={"kb_loader": {"entity_vector_length": 35}} + "entity_linker", config={"entity_vector_length": 35} ) assert len(entity_linker.kb) == 0 assert entity_linker.kb.get_size_entities() == 0 @@ -130,18 +130,9 @@ def test_kb_custom_length(nlp): assert entity_linker.kb.entity_vector_length == 35 -def test_kb_undefined(nlp): - """Test that the EL can't train without defining a KB""" - entity_linker = nlp.add_pipe("entity_linker", config={}) - with pytest.raises(ValueError): - entity_linker.initialize(lambda: []) - - -def test_kb_empty(nlp): - """Test that the EL can't train with an empty KB""" - config = {"kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 342}} - entity_linker = nlp.add_pipe("entity_linker", config=config) - assert len(entity_linker.kb) == 0 +def test_kb_initialize_empty(nlp): + """Test that the EL can't initialize without examples""" + entity_linker = nlp.add_pipe("entity_linker") with pytest.raises(ValueError): entity_linker.initialize(lambda: []) @@ -201,24 +192,21 @@ def test_el_pipe_configuration(nlp): ruler = nlp.add_pipe("entity_ruler") ruler.add_patterns([pattern]) - @registry.misc.register("myAdamKB.v1") - def mykb() -> Callable[["Vocab"], KnowledgeBase]: - def create_kb(vocab): - kb = KnowledgeBase(vocab, entity_vector_length=1) - kb.add_entity(entity="Q2", freq=12, entity_vector=[2]) - kb.add_entity(entity="Q3", freq=5, entity_vector=[3]) - kb.add_alias( - alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1] - ) - return kb - - return create_kb + def create_kb(vocab): + kb = KnowledgeBase(vocab, entity_vector_length=1) + kb.add_entity(entity="Q2", freq=12, entity_vector=[2]) + kb.add_entity(entity="Q3", freq=5, entity_vector=[3]) + kb.add_alias( + alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1] + ) + return kb # run an EL pipe without a trained context encoder, to check the candidate generation step only - nlp.add_pipe( + entity_linker = nlp.add_pipe( "entity_linker", - config={"kb_loader": {"@misc": "myAdamKB.v1"}, "incl_context": False}, + config={"incl_context": False}, ) + entity_linker.set_kb(create_kb) # With the default get_candidates function, matching is case-sensitive text = "Douglas and douglas are not the same." doc = nlp(text) @@ -234,15 +222,15 @@ def test_el_pipe_configuration(nlp): return get_lowercased_candidates # replace the pipe with a new one with with a different candidate generator - nlp.replace_pipe( + entity_linker = nlp.replace_pipe( "entity_linker", "entity_linker", config={ - "kb_loader": {"@misc": "myAdamKB.v1"}, "incl_context": False, "get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"}, }, ) + entity_linker.set_kb(create_kb) doc = nlp(text) assert doc[0].ent_kb_id_ == "Q2" assert doc[1].ent_kb_id_ == "" @@ -334,19 +322,15 @@ def test_preserving_links_asdoc(nlp): """Test that Span.as_doc preserves the existing entity links""" vector_length = 1 - @registry.misc.register("myLocationsKB.v1") - def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]: - def create_kb(vocab): - mykb = KnowledgeBase(vocab, entity_vector_length=vector_length) - # adding entities - mykb.add_entity(entity="Q1", freq=19, entity_vector=[1]) - mykb.add_entity(entity="Q2", freq=8, entity_vector=[1]) - # adding aliases - mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7]) - mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6]) - return mykb - - return create_kb + def create_kb(vocab): + mykb = KnowledgeBase(vocab, entity_vector_length=vector_length) + # adding entities + mykb.add_entity(entity="Q1", freq=19, entity_vector=[1]) + mykb.add_entity(entity="Q2", freq=8, entity_vector=[1]) + # adding aliases + mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7]) + mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6]) + return mykb # set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained) nlp.add_pipe("sentencizer") @@ -356,8 +340,9 @@ def test_preserving_links_asdoc(nlp): ] ruler = nlp.add_pipe("entity_ruler") ruler.add_patterns(patterns) - el_config = {"kb_loader": {"@misc": "myLocationsKB.v1"}, "incl_prior": False} - entity_linker = nlp.add_pipe("entity_linker", config=el_config, last=True) + config = {"incl_prior": False} + entity_linker = nlp.add_pipe("entity_linker", config=config, last=True) + entity_linker.set_kb(create_kb) nlp.initialize() assert entity_linker.model.get_dim("nO") == vector_length @@ -435,30 +420,26 @@ def test_overfitting_IO(): doc = nlp(text) train_examples.append(Example.from_dict(doc, annotation)) - @registry.misc.register("myOverfittingKB.v1") - def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]: - def create_kb(vocab): - # create artificial KB - assign same prior weight to the two russ cochran's - # Q2146908 (Russ Cochran): American golfer - # Q7381115 (Russ Cochran): publisher - mykb = KnowledgeBase(vocab, entity_vector_length=vector_length) - mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3]) - mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7]) - mykb.add_alias( - alias="Russ Cochran", - entities=["Q2146908", "Q7381115"], - probabilities=[0.5, 0.5], - ) - return mykb - - return create_kb + def create_kb(vocab): + # create artificial KB - assign same prior weight to the two russ cochran's + # Q2146908 (Russ Cochran): American golfer + # Q7381115 (Russ Cochran): publisher + mykb = KnowledgeBase(vocab, entity_vector_length=vector_length) + mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3]) + mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7]) + mykb.add_alias( + alias="Russ Cochran", + entities=["Q2146908", "Q7381115"], + probabilities=[0.5, 0.5], + ) + return mykb # Create the Entity Linker component and add it to the pipeline entity_linker = nlp.add_pipe( "entity_linker", - config={"kb_loader": {"@misc": "myOverfittingKB.v1"}}, last=True, ) + entity_linker.set_kb(create_kb) # train the NEL pipe optimizer = nlp.initialize(get_examples=lambda: train_examples) diff --git a/spacy/tests/regression/test_issue5230.py b/spacy/tests/regression/test_issue5230.py index 5e320996a..9fda413a3 100644 --- a/spacy/tests/regression/test_issue5230.py +++ b/spacy/tests/regression/test_issue5230.py @@ -71,17 +71,13 @@ def tagger(): def entity_linker(): nlp = Language() - @registry.misc.register("TestIssue5230KB.v1") - def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]: - def create_kb(vocab): - kb = KnowledgeBase(vocab, entity_vector_length=1) - kb.add_entity("test", 0.0, zeros((1, 1), dtype="f")) - return kb + def create_kb(vocab): + kb = KnowledgeBase(vocab, entity_vector_length=1) + kb.add_entity("test", 0.0, zeros((1, 1), dtype="f")) + return kb - return create_kb - - config = {"kb_loader": {"@misc": "TestIssue5230KB.v1"}} - entity_linker = nlp.add_pipe("entity_linker", config=config) + entity_linker = nlp.add_pipe("entity_linker") + entity_linker.set_kb(create_kb) # need to add model for two reasons: # 1. no model leads to error in serialization, # 2. the affected line is the one for model serialization diff --git a/spacy/tests/serialize/test_serialize_kb.py b/spacy/tests/serialize/test_serialize_kb.py index 63736418b..352c335ea 100644 --- a/spacy/tests/serialize/test_serialize_kb.py +++ b/spacy/tests/serialize/test_serialize_kb.py @@ -1,11 +1,12 @@ from typing import Callable from spacy import util -from spacy.lang.en import English -from spacy.util import ensure_path, registry +from spacy.util import ensure_path, registry, load_model_from_config from spacy.kb import KnowledgeBase +from thinc.api import Config from ..util import make_tempdir +from numpy import zeros def test_serialize_kb_disk(en_vocab): @@ -80,6 +81,28 @@ def _check_kb(kb): def test_serialize_subclassed_kb(): """Check that IO of a custom KB works fine as part of an EL pipe.""" + config_string = """ + [nlp] + lang = "en" + pipeline = ["entity_linker"] + + [components] + + [components.entity_linker] + factory = "entity_linker" + + [initialize] + + [initialize.components] + + [initialize.components.entity_linker] + + [initialize.components.entity_linker.kb_loader] + @misc = "spacy.CustomKB.v1" + entity_vector_length = 342 + custom_field = 666 + """ + class SubKnowledgeBase(KnowledgeBase): def __init__(self, vocab, entity_vector_length, custom_field): super().__init__(vocab, entity_vector_length) @@ -90,23 +113,21 @@ def test_serialize_subclassed_kb(): entity_vector_length: int, custom_field: int ) -> Callable[["Vocab"], KnowledgeBase]: def custom_kb_factory(vocab): - return SubKnowledgeBase( + kb = SubKnowledgeBase( vocab=vocab, entity_vector_length=entity_vector_length, custom_field=custom_field, ) + kb.add_entity("random_entity", 0.0, zeros(entity_vector_length)) + return kb return custom_kb_factory - nlp = English() - config = { - "kb_loader": { - "@misc": "spacy.CustomKB.v1", - "entity_vector_length": 342, - "custom_field": 666, - } - } - entity_linker = nlp.add_pipe("entity_linker", config=config) + config = Config().from_str(config_string) + nlp = load_model_from_config(config, auto_fill=True) + nlp.initialize() + + entity_linker = nlp.get_pipe("entity_linker") assert type(entity_linker.kb) == SubKnowledgeBase assert entity_linker.kb.entity_vector_length == 342 assert entity_linker.kb.custom_field == 666 @@ -116,6 +137,7 @@ def test_serialize_subclassed_kb(): nlp.to_disk(tmp_dir) nlp2 = util.load_model_from_path(tmp_dir) entity_linker2 = nlp2.get_pipe("entity_linker") - assert type(entity_linker2.kb) == SubKnowledgeBase + # After IO, the KB is the standard one + assert type(entity_linker2.kb) == KnowledgeBase assert entity_linker2.kb.entity_vector_length == 342 - assert entity_linker2.kb.custom_field == 666 + assert not hasattr(entity_linker2.kb, "custom_field") diff --git a/website/docs/api/language.md b/website/docs/api/language.md index 6257199c9..51e9a5e10 100644 --- a/website/docs/api/language.md +++ b/website/docs/api/language.md @@ -524,7 +524,7 @@ Get a pipeline component for a given component name. ## Language.replace_pipe {#replace_pipe tag="method" new="2"} -Replace a component in the pipeline. +Replace a component in the pipeline and return the new component. @@ -538,7 +538,7 @@ and instead expects the **name of a component factory** registered using > #### Example > > ```python -> nlp.replace_pipe("parser", my_custom_parser) +> new_parser = nlp.replace_pipe("parser", "my_custom_parser") > ``` | Name | Description | @@ -548,6 +548,7 @@ and instead expects the **name of a component factory** registered using | _keyword-only_ | | | `config` 3 | Optional config parameters to use for the new component. Will be merged with the `default_config` specified by the component factory. ~~Optional[Dict[str, Any]]~~ | | `validate` 3 | Whether to validate the component config and arguments against the types expected by the factory. Defaults to `True`. ~~bool~~ | +| **RETURNS** | The new pipeline component. ~~Callable[[Doc], Doc]~~ | ## Language.rename_pipe {#rename_pipe tag="method" new="2"} diff --git a/website/docs/usage/saving-loading.md b/website/docs/usage/saving-loading.md index c19ff39eb..968689baf 100644 --- a/website/docs/usage/saving-loading.md +++ b/website/docs/usage/saving-loading.md @@ -297,7 +297,7 @@ packages. This lets one application easily customize the behavior of another, by exposing an entry point in its `setup.py`. For a quick and fun intro to entry points in Python, check out [this excellent blog post](https://amir.rachum.com/blog/2017/07/28/python-entry-points/). -spaCy can load custom function from several different entry points to add +spaCy can load custom functions from several different entry points to add pipeline component factories, language classes and other settings. To make spaCy use your entry points, your package needs to expose them and it needs to be installed in the same environment – that's it. diff --git a/website/docs/usage/v3.md b/website/docs/usage/v3.md index 0f30029e7..d9ab00b97 100644 --- a/website/docs/usage/v3.md +++ b/website/docs/usage/v3.md @@ -395,7 +395,7 @@ type-check model definitions. For data validation, spaCy v3.0 adopts [`pydantic`](https://github.com/samuelcolvin/pydantic). It also powers the data validation of Thinc's [config system](https://thinc.ai/docs/usage-config), which -lets you to register **custom functions with typed arguments**, reference them +lets you register **custom functions with typed arguments**, reference them in your config and see validation errors if the argument values don't match.