From 6b8bdb2d390c4d26577754c213170a0190bb2cc5 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 7 Oct 2020 14:58:16 +0200 Subject: [PATCH] add init_config to nlp.create_pipe --- spacy/language.py | 23 +++++++++++++++++++++- spacy/tests/pipeline/test_entity_linker.py | 22 +++++++++++++-------- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index ba244617e..e3b2285fb 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -600,6 +600,7 @@ class Language: *, config: Optional[Dict[str, Any]] = SimpleFrozenDict(), raw_config: Optional[Config] = None, + init_config: Optional[Dict[str, Any]] = SimpleFrozenDict(), validate: bool = True, ) -> Callable[[Doc], Doc]: """Create a pipeline component. Mostly used internally. To create and @@ -611,6 +612,9 @@ class Language: config (Optional[Dict[str, Any]]): Config parameters to use for this component. Will be merged with default config, if available. raw_config (Optional[Config]): Internals: the non-interpolated config. + init_config (Optional[Dict[str, Any]]): Config parameters to use to + initialize this component. Will be used to update the internal + 'initialize' config. validate (bool): Whether to validate the component config against the arguments and types expected by the factory. RETURNS (Callable[[Doc], Doc]): The pipeline component. @@ -621,8 +625,13 @@ class Language: if not isinstance(config, dict): err = Errors.E962.format(style="config", name=name, cfg_type=type(config)) raise ValueError(err) + if not isinstance(init_config, dict): + err = Errors.E962.format(style="init_config", name=name, cfg_type=type(init_config)) + raise ValueError(err) if not srsly.is_json_serializable(config): raise ValueError(Errors.E961.format(config=config)) + if not srsly.is_json_serializable(init_config): + raise ValueError(Errors.E961.format(config=init_config)) if not self.has_factory(factory_name): err = Errors.E002.format( name=factory_name, @@ -634,6 +643,8 @@ class Language: raise ValueError(err) pipe_meta = self.get_factory_meta(factory_name) config = config or {} + if init_config: + self._config["initialize"]["components"][name] = init_config # This is unideal, but the alternative would mean you always need to # specify the full config settings, which is not really viable. if pipe_meta.default_config: @@ -708,6 +719,7 @@ class Language: source: Optional["Language"] = None, config: Optional[Dict[str, Any]] = SimpleFrozenDict(), raw_config: Optional[Config] = None, + init_config: Optional[Dict[str, Any]] = SimpleFrozenDict(), validate: bool = True, ) -> Callable[[Doc], Doc]: """Add a component to the processing pipeline. Valid components are @@ -730,6 +742,9 @@ class Language: config (Optional[Dict[str, Any]]): Config parameters to use for this component. Will be merged with default config, if available. raw_config (Optional[Config]): Internals: the non-interpolated config. + init_config (Optional[Dict[str, Any]]): Config parameters to use to + initialize this component. Will be used to update the internal + 'initialize' config. validate (bool): Whether to validate the component config against the arguments and types expected by the factory. RETURNS (Callable[[Doc], Doc]): The pipeline component. @@ -763,6 +778,7 @@ class Language: name=name, config=config, raw_config=raw_config, + init_config=init_config, validate=validate, ) pipe_index = self._get_pipe_index(before, after, first, last) @@ -842,6 +858,7 @@ class Language: factory_name: str, *, config: Dict[str, Any] = SimpleFrozenDict(), + init_config: Dict[str, Any] = SimpleFrozenDict(), validate: bool = True, ) -> None: """Replace a component in the pipeline. @@ -850,6 +867,9 @@ class Language: factory_name (str): Factory name of replacement component. config (Optional[Dict[str, Any]]): Config parameters to use for this component. Will be merged with default config, if available. + init_config (Optional[Dict[str, Any]]): Config parameters to use to + initialize this component. Will be used to update the internal + 'initialize' config. validate (bool): Whether to validate the component config against the arguments and types expected by the factory. @@ -866,13 +886,14 @@ class Language: self.remove_pipe(name) if not len(self._components) or pipe_index == len(self._components): # we have no components to insert before/after, or we're replacing the last component - self.add_pipe(factory_name, name=name, config=config, validate=validate) + self.add_pipe(factory_name, name=name, config=config, init_config=init_config, validate=validate) else: self.add_pipe( factory_name, name=name, before=pipe_index, config=config, + init_config=init_config, validate=validate, ) diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 66de54c06..cf9fce2a7 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -110,7 +110,7 @@ def test_kb_invalid_entity_vector(nlp): def test_kb_default(nlp): - """Test that the default (empty) KB is loaded when not providing a config""" + """Test that the default (empty) KB is loaded upon construction""" entity_linker = nlp.add_pipe("entity_linker", config={}) assert len(entity_linker.kb) == 0 assert entity_linker.kb.get_size_entities() == 0 @@ -122,7 +122,7 @@ def test_kb_default(nlp): def test_kb_custom_length(nlp): """Test that the default (empty) KB can be configured with a custom entity length""" entity_linker = nlp.add_pipe( - "entity_linker", config={"kb_loader": {"entity_vector_length": 35}} + "entity_linker", config={"entity_vector_length": 35} ) assert len(entity_linker.kb) == 0 assert entity_linker.kb.get_size_entities() == 0 @@ -140,7 +140,7 @@ def test_kb_undefined(nlp): def test_kb_empty(nlp): """Test that the EL can't train with an empty KB""" config = {"kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 342}} - entity_linker = nlp.add_pipe("entity_linker", config=config) + entity_linker = nlp.add_pipe("entity_linker", init_config=config) assert len(entity_linker.kb) == 0 with pytest.raises(ValueError): entity_linker.initialize(lambda: []) @@ -217,8 +217,10 @@ def test_el_pipe_configuration(nlp): # run an EL pipe without a trained context encoder, to check the candidate generation step only nlp.add_pipe( "entity_linker", - config={"kb_loader": {"@misc": "myAdamKB.v1"}, "incl_context": False}, + config={"incl_context": False}, + init_config={"kb_loader": {"@misc": "myAdamKB.v1"}}, ) + nlp.initialize() # With the default get_candidates function, matching is case-sensitive text = "Douglas and douglas are not the same." doc = nlp(text) @@ -238,11 +240,14 @@ def test_el_pipe_configuration(nlp): "entity_linker", "entity_linker", config={ - "kb_loader": {"@misc": "myAdamKB.v1"}, "incl_context": False, "get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"}, }, + init_config={ + "kb_loader": {"@misc": "myAdamKB.v1"}, + }, ) + nlp.initialize() doc = nlp(text) assert doc[0].ent_kb_id_ == "Q2" assert doc[1].ent_kb_id_ == "" @@ -356,8 +361,9 @@ def test_preserving_links_asdoc(nlp): ] ruler = nlp.add_pipe("entity_ruler") ruler.add_patterns(patterns) - el_config = {"kb_loader": {"@misc": "myLocationsKB.v1"}, "incl_prior": False} - entity_linker = nlp.add_pipe("entity_linker", config=el_config, last=True) + config = {"incl_prior": False} + init_config = {"kb_loader": {"@misc": "myLocationsKB.v1"}} + entity_linker = nlp.add_pipe("entity_linker", config=config, init_config=init_config, last=True) nlp.initialize() assert entity_linker.model.get_dim("nO") == vector_length @@ -456,7 +462,7 @@ def test_overfitting_IO(): # Create the Entity Linker component and add it to the pipeline entity_linker = nlp.add_pipe( "entity_linker", - config={"kb_loader": {"@misc": "myOverfittingKB.v1"}}, + init_config={"kb_loader": {"@misc": "myOverfittingKB.v1"}}, last=True, )