diff --git a/spacy/language.py b/spacy/language.py index e3b2285fb..3a0ea783e 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -600,7 +600,6 @@ class Language: *, config: Optional[Dict[str, Any]] = SimpleFrozenDict(), raw_config: Optional[Config] = None, - init_config: Optional[Dict[str, Any]] = SimpleFrozenDict(), validate: bool = True, ) -> Callable[[Doc], Doc]: """Create a pipeline component. Mostly used internally. To create and @@ -612,9 +611,6 @@ class Language: config (Optional[Dict[str, Any]]): Config parameters to use for this component. Will be merged with default config, if available. raw_config (Optional[Config]): Internals: the non-interpolated config. - init_config (Optional[Dict[str, Any]]): Config parameters to use to - initialize this component. Will be used to update the internal - 'initialize' config. validate (bool): Whether to validate the component config against the arguments and types expected by the factory. RETURNS (Callable[[Doc], Doc]): The pipeline component. @@ -625,13 +621,8 @@ class Language: if not isinstance(config, dict): err = Errors.E962.format(style="config", name=name, cfg_type=type(config)) raise ValueError(err) - if not isinstance(init_config, dict): - err = Errors.E962.format(style="init_config", name=name, cfg_type=type(init_config)) - raise ValueError(err) if not srsly.is_json_serializable(config): raise ValueError(Errors.E961.format(config=config)) - if not srsly.is_json_serializable(init_config): - raise ValueError(Errors.E961.format(config=init_config)) if not self.has_factory(factory_name): err = Errors.E002.format( name=factory_name, @@ -643,8 +635,6 @@ class Language: raise ValueError(err) pipe_meta = self.get_factory_meta(factory_name) config = config or {} - if init_config: - self._config["initialize"]["components"][name] = init_config # This is unideal, but the alternative would mean you always need to # specify the full config settings, which is not really viable. 
if pipe_meta.default_config: @@ -719,7 +709,6 @@ class Language: source: Optional["Language"] = None, config: Optional[Dict[str, Any]] = SimpleFrozenDict(), raw_config: Optional[Config] = None, - init_config: Optional[Dict[str, Any]] = SimpleFrozenDict(), validate: bool = True, ) -> Callable[[Doc], Doc]: """Add a component to the processing pipeline. Valid components are @@ -742,9 +731,6 @@ class Language: config (Optional[Dict[str, Any]]): Config parameters to use for this component. Will be merged with default config, if available. raw_config (Optional[Config]): Internals: the non-interpolated config. - init_config (Optional[Dict[str, Any]]): Config parameters to use to - initialize this component. Will be used to update the internal - 'initialize' config. validate (bool): Whether to validate the component config against the arguments and types expected by the factory. RETURNS (Callable[[Doc], Doc]): The pipeline component. @@ -778,7 +764,6 @@ class Language: name=name, config=config, raw_config=raw_config, - init_config=init_config, validate=validate, ) pipe_index = self._get_pipe_index(before, after, first, last) @@ -858,20 +843,17 @@ class Language: factory_name: str, *, config: Dict[str, Any] = SimpleFrozenDict(), - init_config: Dict[str, Any] = SimpleFrozenDict(), validate: bool = True, - ) -> None: + ) -> Callable[[Doc], Doc]: """Replace a component in the pipeline. name (str): Name of the component to replace. factory_name (str): Factory name of replacement component. config (Optional[Dict[str, Any]]): Config parameters to use for this component. Will be merged with default config, if available. - init_config (Optional[Dict[str, Any]]): Config parameters to use to - initialize this component. Will be used to update the internal - 'initialize' config. validate (bool): Whether to validate the component config against the arguments and types expected by the factory. + RETURNS (Callable[[Doc], Doc]): The new pipeline component. 
DOCS: https://nightly.spacy.io/api/language#replace_pipe """ @@ -886,14 +868,15 @@ class Language: self.remove_pipe(name) if not len(self._components) or pipe_index == len(self._components): # we have no components to insert before/after, or we're replacing the last component - self.add_pipe(factory_name, name=name, config=config, init_config=init_config, validate=validate) + return self.add_pipe( + factory_name, name=name, config=config, validate=validate + ) else: - self.add_pipe( + return self.add_pipe( factory_name, name=name, before=pipe_index, config=config, - init_config=init_config, validate=validate, ) @@ -1321,7 +1304,11 @@ class Language: kwargs.setdefault("batch_size", batch_size) # non-trainable components may have a pipe() implementation that refers to dummy # predict and set_annotations methods - if not hasattr(pipe, "pipe") or not hasattr(pipe, "is_trainable") or not pipe.is_trainable(): + if ( + not hasattr(pipe, "pipe") + or not hasattr(pipe, "is_trainable") + or not pipe.is_trainable() + ): docs = _pipe(docs, pipe, kwargs) else: docs = pipe.pipe(docs, **kwargs) @@ -1433,7 +1420,11 @@ class Language: kwargs.setdefault("batch_size", batch_size) # non-trainable components may have a pipe() implementation that refers to dummy # predict and set_annotations methods - if hasattr(proc, "pipe") and hasattr(proc, "is_trainable") and proc.is_trainable(): + if ( + hasattr(proc, "pipe") + and hasattr(proc, "is_trainable") + and proc.is_trainable() + ): f = functools.partial(proc.pipe, **kwargs) else: # Apply the function, but yield the doc diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index b371ca9a4..eec591995 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -142,6 +142,12 @@ class EntityLinker(Pipe): # create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'. 
self.kb = empty_kb(entity_vector_length)(self.vocab) + def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]): + """Define the KB of this pipe by providing a function that will + create it using this object's vocab.""" + self.kb = kb_loader(self.vocab) + self.cfg["entity_vector_length"] = self.kb.entity_vector_length + def validate_kb(self) -> None: # Raise an error if the knowledge base is not initialized. if len(self.kb) == 0: @@ -168,8 +174,7 @@ class EntityLinker(Pipe): """ self._ensure_examples(get_examples) if kb_loader is not None: - self.kb = kb_loader(self.vocab) - self.cfg["entity_vector_length"] = self.kb.entity_vector_length + self.set_kb(kb_loader) self.validate_kb() nO = self.kb.entity_vector_length doc_sample = [] diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index cf9fce2a7..e77be74ad 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -130,18 +130,9 @@ def test_kb_custom_length(nlp): assert entity_linker.kb.entity_vector_length == 35 -def test_kb_undefined(nlp): - """Test that the EL can't train without defining a KB""" - entity_linker = nlp.add_pipe("entity_linker", config={}) - with pytest.raises(ValueError): - entity_linker.initialize(lambda: []) - - -def test_kb_empty(nlp): - """Test that the EL can't train with an empty KB""" - config = {"kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 342}} - entity_linker = nlp.add_pipe("entity_linker", init_config=config) - assert len(entity_linker.kb) == 0 +def test_kb_initialize_empty(nlp): + """Test that the EL can't initialize without examples""" + entity_linker = nlp.add_pipe("entity_linker") with pytest.raises(ValueError): entity_linker.initialize(lambda: []) @@ -201,26 +192,21 @@ def test_el_pipe_configuration(nlp): ruler = nlp.add_pipe("entity_ruler") ruler.add_patterns([pattern]) - @registry.misc.register("myAdamKB.v1") - def mykb() -> Callable[["Vocab"], 
KnowledgeBase]: - def create_kb(vocab): - kb = KnowledgeBase(vocab, entity_vector_length=1) - kb.add_entity(entity="Q2", freq=12, entity_vector=[2]) - kb.add_entity(entity="Q3", freq=5, entity_vector=[3]) - kb.add_alias( - alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1] - ) - return kb - - return create_kb + def create_kb(vocab): + kb = KnowledgeBase(vocab, entity_vector_length=1) + kb.add_entity(entity="Q2", freq=12, entity_vector=[2]) + kb.add_entity(entity="Q3", freq=5, entity_vector=[3]) + kb.add_alias( + alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1] + ) + return kb # run an EL pipe without a trained context encoder, to check the candidate generation step only - nlp.add_pipe( + entity_linker = nlp.add_pipe( "entity_linker", config={"incl_context": False}, - init_config={"kb_loader": {"@misc": "myAdamKB.v1"}}, ) - nlp.initialize() + entity_linker.set_kb(create_kb) # With the default get_candidates function, matching is case-sensitive text = "Douglas and douglas are not the same." 
doc = nlp(text) @@ -236,18 +222,15 @@ def test_el_pipe_configuration(nlp): return get_lowercased_candidates # replace the pipe with a new one with with a different candidate generator - nlp.replace_pipe( + entity_linker = nlp.replace_pipe( "entity_linker", "entity_linker", config={ "incl_context": False, "get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"}, }, - init_config={ - "kb_loader": {"@misc": "myAdamKB.v1"}, - }, ) - nlp.initialize() + entity_linker.set_kb(create_kb) doc = nlp(text) assert doc[0].ent_kb_id_ == "Q2" assert doc[1].ent_kb_id_ == "" @@ -339,19 +322,15 @@ def test_preserving_links_asdoc(nlp): """Test that Span.as_doc preserves the existing entity links""" vector_length = 1 - @registry.misc.register("myLocationsKB.v1") - def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]: - def create_kb(vocab): - mykb = KnowledgeBase(vocab, entity_vector_length=vector_length) - # adding entities - mykb.add_entity(entity="Q1", freq=19, entity_vector=[1]) - mykb.add_entity(entity="Q2", freq=8, entity_vector=[1]) - # adding aliases - mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7]) - mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6]) - return mykb - - return create_kb + def create_kb(vocab): + mykb = KnowledgeBase(vocab, entity_vector_length=vector_length) + # adding entities + mykb.add_entity(entity="Q1", freq=19, entity_vector=[1]) + mykb.add_entity(entity="Q2", freq=8, entity_vector=[1]) + # adding aliases + mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7]) + mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6]) + return mykb # set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained) nlp.add_pipe("sentencizer") @@ -362,8 +341,8 @@ def test_preserving_links_asdoc(nlp): ruler = nlp.add_pipe("entity_ruler") ruler.add_patterns(patterns) config = {"incl_prior": False} - init_config = {"kb_loader": {"@misc": "myLocationsKB.v1"}} - entity_linker = 
nlp.add_pipe("entity_linker", config=config, init_config=init_config, last=True) + entity_linker = nlp.add_pipe("entity_linker", config=config, last=True) + entity_linker.set_kb(create_kb) nlp.initialize() assert entity_linker.model.get_dim("nO") == vector_length @@ -441,30 +420,26 @@ def test_overfitting_IO(): doc = nlp(text) train_examples.append(Example.from_dict(doc, annotation)) - @registry.misc.register("myOverfittingKB.v1") - def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]: - def create_kb(vocab): - # create artificial KB - assign same prior weight to the two russ cochran's - # Q2146908 (Russ Cochran): American golfer - # Q7381115 (Russ Cochran): publisher - mykb = KnowledgeBase(vocab, entity_vector_length=vector_length) - mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3]) - mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7]) - mykb.add_alias( - alias="Russ Cochran", - entities=["Q2146908", "Q7381115"], - probabilities=[0.5, 0.5], - ) - return mykb - - return create_kb + def create_kb(vocab): + # create artificial KB - assign same prior weight to the two russ cochran's + # Q2146908 (Russ Cochran): American golfer + # Q7381115 (Russ Cochran): publisher + mykb = KnowledgeBase(vocab, entity_vector_length=vector_length) + mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3]) + mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7]) + mykb.add_alias( + alias="Russ Cochran", + entities=["Q2146908", "Q7381115"], + probabilities=[0.5, 0.5], + ) + return mykb # Create the Entity Linker component and add it to the pipeline entity_linker = nlp.add_pipe( "entity_linker", - init_config={"kb_loader": {"@misc": "myOverfittingKB.v1"}}, last=True, ) + entity_linker.set_kb(create_kb) # train the NEL pipe optimizer = nlp.initialize(get_examples=lambda: train_examples) diff --git a/spacy/tests/regression/test_issue5230.py b/spacy/tests/regression/test_issue5230.py index aa4cc9be1..9fda413a3 100644 --- 
a/spacy/tests/regression/test_issue5230.py +++ b/spacy/tests/regression/test_issue5230.py @@ -71,17 +71,13 @@ def tagger(): def entity_linker(): nlp = Language() - @registry.misc.register("TestIssue5230KB.v1") - def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]: - def create_kb(vocab): - kb = KnowledgeBase(vocab, entity_vector_length=1) - kb.add_entity("test", 0.0, zeros((1, 1), dtype="f")) - return kb + def create_kb(vocab): + kb = KnowledgeBase(vocab, entity_vector_length=1) + kb.add_entity("test", 0.0, zeros((1, 1), dtype="f")) + return kb - return create_kb - - init_config = {"kb_loader": {"@misc": "TestIssue5230KB.v1"}} - entity_linker = nlp.add_pipe("entity_linker", init_config=init_config) + entity_linker = nlp.add_pipe("entity_linker") + entity_linker.set_kb(create_kb) # need to add model for two reasons: # 1. no model leads to error in serialization, # 2. the affected line is the one for model serialization diff --git a/spacy/tests/serialize/test_serialize_kb.py b/spacy/tests/serialize/test_serialize_kb.py index 84e7c8ec2..352c335ea 100644 --- a/spacy/tests/serialize/test_serialize_kb.py +++ b/spacy/tests/serialize/test_serialize_kb.py @@ -1,9 +1,9 @@ from typing import Callable from spacy import util -from spacy.lang.en import English -from spacy.util import ensure_path, registry +from spacy.util import ensure_path, registry, load_model_from_config from spacy.kb import KnowledgeBase +from thinc.api import Config from ..util import make_tempdir from numpy import zeros @@ -81,6 +81,28 @@ def _check_kb(kb): def test_serialize_subclassed_kb(): """Check that IO of a custom KB works fine as part of an EL pipe.""" + config_string = """ + [nlp] + lang = "en" + pipeline = ["entity_linker"] + + [components] + + [components.entity_linker] + factory = "entity_linker" + + [initialize] + + [initialize.components] + + [initialize.components.entity_linker] + + [initialize.components.entity_linker.kb_loader] + @misc = "spacy.CustomKB.v1" + entity_vector_length = 342 + 
custom_field = 666 + """ + class SubKnowledgeBase(KnowledgeBase): def __init__(self, vocab, entity_vector_length, custom_field): super().__init__(vocab, entity_vector_length) @@ -101,16 +123,11 @@ def test_serialize_subclassed_kb(): return custom_kb_factory - nlp = English() - config = { - "kb_loader": { - "@misc": "spacy.CustomKB.v1", - "entity_vector_length": 342, - "custom_field": 666, - } - } - entity_linker = nlp.add_pipe("entity_linker", init_config=config) + config = Config().from_str(config_string) + nlp = load_model_from_config(config, auto_fill=True) nlp.initialize() + + entity_linker = nlp.get_pipe("entity_linker") assert type(entity_linker.kb) == SubKnowledgeBase assert entity_linker.kb.entity_vector_length == 342 assert entity_linker.kb.custom_field == 666 diff --git a/website/docs/api/language.md b/website/docs/api/language.md index 6257199c9..51e9a5e10 100644 --- a/website/docs/api/language.md +++ b/website/docs/api/language.md @@ -524,7 +524,7 @@ Get a pipeline component for a given component name. ## Language.replace_pipe {#replace_pipe tag="method" new="2"} -Replace a component in the pipeline. +Replace a component in the pipeline and return the new component. @@ -538,7 +538,7 @@ and instead expects the **name of a component factory** registered using > #### Example > > ```python -> nlp.replace_pipe("parser", my_custom_parser) +> new_parser = nlp.replace_pipe("parser", "my_custom_parser") > ``` | Name | Description | @@ -548,6 +548,7 @@ and instead expects the **name of a component factory** registered using | _keyword-only_ | | | `config` 3 | Optional config parameters to use for the new component. Will be merged with the `default_config` specified by the component factory. ~~Optional[Dict[str, Any]]~~ | | `validate` 3 | Whether to validate the component config and arguments against the types expected by the factory. Defaults to `True`. ~~bool~~ | +| **RETURNS** | The new pipeline component. 
~~Callable[[Doc], Doc]~~ | ## Language.rename_pipe {#rename_pipe tag="method" new="2"}