add init_config to nlp.create_pipe

This commit is contained in:
svlandeg 2020-10-07 14:58:16 +02:00
parent 33c2d4af16
commit 6b8bdb2d39
2 changed files with 36 additions and 9 deletions

View File

@ -600,6 +600,7 @@ class Language:
*, *,
config: Optional[Dict[str, Any]] = SimpleFrozenDict(), config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
raw_config: Optional[Config] = None, raw_config: Optional[Config] = None,
init_config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
validate: bool = True, validate: bool = True,
) -> Callable[[Doc], Doc]: ) -> Callable[[Doc], Doc]:
"""Create a pipeline component. Mostly used internally. To create and """Create a pipeline component. Mostly used internally. To create and
@ -611,6 +612,9 @@ class Language:
config (Optional[Dict[str, Any]]): Config parameters to use for this config (Optional[Dict[str, Any]]): Config parameters to use for this
component. Will be merged with default config, if available. component. Will be merged with default config, if available.
raw_config (Optional[Config]): Internals: the non-interpolated config. raw_config (Optional[Config]): Internals: the non-interpolated config.
init_config (Optional[Dict[str, Any]]): Config parameters to use to
initialize this component. Will be used to update the internal
'initialize' config.
validate (bool): Whether to validate the component config against the validate (bool): Whether to validate the component config against the
arguments and types expected by the factory. arguments and types expected by the factory.
RETURNS (Callable[[Doc], Doc]): The pipeline component. RETURNS (Callable[[Doc], Doc]): The pipeline component.
@ -621,8 +625,13 @@ class Language:
if not isinstance(config, dict): if not isinstance(config, dict):
err = Errors.E962.format(style="config", name=name, cfg_type=type(config)) err = Errors.E962.format(style="config", name=name, cfg_type=type(config))
raise ValueError(err) raise ValueError(err)
if not isinstance(init_config, dict):
err = Errors.E962.format(style="init_config", name=name, cfg_type=type(init_config))
raise ValueError(err)
if not srsly.is_json_serializable(config): if not srsly.is_json_serializable(config):
raise ValueError(Errors.E961.format(config=config)) raise ValueError(Errors.E961.format(config=config))
if not srsly.is_json_serializable(init_config):
raise ValueError(Errors.E961.format(config=init_config))
if not self.has_factory(factory_name): if not self.has_factory(factory_name):
err = Errors.E002.format( err = Errors.E002.format(
name=factory_name, name=factory_name,
@ -634,6 +643,8 @@ class Language:
raise ValueError(err) raise ValueError(err)
pipe_meta = self.get_factory_meta(factory_name) pipe_meta = self.get_factory_meta(factory_name)
config = config or {} config = config or {}
if init_config:
self._config["initialize"]["components"][name] = init_config
# This is unideal, but the alternative would mean you always need to # This is unideal, but the alternative would mean you always need to
# specify the full config settings, which is not really viable. # specify the full config settings, which is not really viable.
if pipe_meta.default_config: if pipe_meta.default_config:
@ -708,6 +719,7 @@ class Language:
source: Optional["Language"] = None, source: Optional["Language"] = None,
config: Optional[Dict[str, Any]] = SimpleFrozenDict(), config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
raw_config: Optional[Config] = None, raw_config: Optional[Config] = None,
init_config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
validate: bool = True, validate: bool = True,
) -> Callable[[Doc], Doc]: ) -> Callable[[Doc], Doc]:
"""Add a component to the processing pipeline. Valid components are """Add a component to the processing pipeline. Valid components are
@ -730,6 +742,9 @@ class Language:
config (Optional[Dict[str, Any]]): Config parameters to use for this config (Optional[Dict[str, Any]]): Config parameters to use for this
component. Will be merged with default config, if available. component. Will be merged with default config, if available.
raw_config (Optional[Config]): Internals: the non-interpolated config. raw_config (Optional[Config]): Internals: the non-interpolated config.
init_config (Optional[Dict[str, Any]]): Config parameters to use to
initialize this component. Will be used to update the internal
'initialize' config.
validate (bool): Whether to validate the component config against the validate (bool): Whether to validate the component config against the
arguments and types expected by the factory. arguments and types expected by the factory.
RETURNS (Callable[[Doc], Doc]): The pipeline component. RETURNS (Callable[[Doc], Doc]): The pipeline component.
@ -763,6 +778,7 @@ class Language:
name=name, name=name,
config=config, config=config,
raw_config=raw_config, raw_config=raw_config,
init_config=init_config,
validate=validate, validate=validate,
) )
pipe_index = self._get_pipe_index(before, after, first, last) pipe_index = self._get_pipe_index(before, after, first, last)
@ -842,6 +858,7 @@ class Language:
factory_name: str, factory_name: str,
*, *,
config: Dict[str, Any] = SimpleFrozenDict(), config: Dict[str, Any] = SimpleFrozenDict(),
init_config: Dict[str, Any] = SimpleFrozenDict(),
validate: bool = True, validate: bool = True,
) -> None: ) -> None:
"""Replace a component in the pipeline. """Replace a component in the pipeline.
@ -850,6 +867,9 @@ class Language:
factory_name (str): Factory name of replacement component. factory_name (str): Factory name of replacement component.
config (Optional[Dict[str, Any]]): Config parameters to use for this config (Optional[Dict[str, Any]]): Config parameters to use for this
component. Will be merged with default config, if available. component. Will be merged with default config, if available.
init_config (Optional[Dict[str, Any]]): Config parameters to use to
initialize this component. Will be used to update the internal
'initialize' config.
validate (bool): Whether to validate the component config against the validate (bool): Whether to validate the component config against the
arguments and types expected by the factory. arguments and types expected by the factory.
@ -866,13 +886,14 @@ class Language:
self.remove_pipe(name) self.remove_pipe(name)
if not len(self._components) or pipe_index == len(self._components): if not len(self._components) or pipe_index == len(self._components):
# we have no components to insert before/after, or we're replacing the last component # we have no components to insert before/after, or we're replacing the last component
self.add_pipe(factory_name, name=name, config=config, validate=validate) self.add_pipe(factory_name, name=name, config=config, init_config=init_config, validate=validate)
else: else:
self.add_pipe( self.add_pipe(
factory_name, factory_name,
name=name, name=name,
before=pipe_index, before=pipe_index,
config=config, config=config,
init_config=init_config,
validate=validate, validate=validate,
) )

View File

@ -110,7 +110,7 @@ def test_kb_invalid_entity_vector(nlp):
def test_kb_default(nlp): def test_kb_default(nlp):
"""Test that the default (empty) KB is loaded when not providing a config""" """Test that the default (empty) KB is loaded upon construction"""
entity_linker = nlp.add_pipe("entity_linker", config={}) entity_linker = nlp.add_pipe("entity_linker", config={})
assert len(entity_linker.kb) == 0 assert len(entity_linker.kb) == 0
assert entity_linker.kb.get_size_entities() == 0 assert entity_linker.kb.get_size_entities() == 0
@ -122,7 +122,7 @@ def test_kb_default(nlp):
def test_kb_custom_length(nlp): def test_kb_custom_length(nlp):
"""Test that the default (empty) KB can be configured with a custom entity length""" """Test that the default (empty) KB can be configured with a custom entity length"""
entity_linker = nlp.add_pipe( entity_linker = nlp.add_pipe(
"entity_linker", config={"kb_loader": {"entity_vector_length": 35}} "entity_linker", config={"entity_vector_length": 35}
) )
assert len(entity_linker.kb) == 0 assert len(entity_linker.kb) == 0
assert entity_linker.kb.get_size_entities() == 0 assert entity_linker.kb.get_size_entities() == 0
@ -140,7 +140,7 @@ def test_kb_undefined(nlp):
def test_kb_empty(nlp): def test_kb_empty(nlp):
"""Test that the EL can't train with an empty KB""" """Test that the EL can't train with an empty KB"""
config = {"kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 342}} config = {"kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
entity_linker = nlp.add_pipe("entity_linker", config=config) entity_linker = nlp.add_pipe("entity_linker", init_config=config)
assert len(entity_linker.kb) == 0 assert len(entity_linker.kb) == 0
with pytest.raises(ValueError): with pytest.raises(ValueError):
entity_linker.initialize(lambda: []) entity_linker.initialize(lambda: [])
@ -217,8 +217,10 @@ def test_el_pipe_configuration(nlp):
# run an EL pipe without a trained context encoder, to check the candidate generation step only # run an EL pipe without a trained context encoder, to check the candidate generation step only
nlp.add_pipe( nlp.add_pipe(
"entity_linker", "entity_linker",
config={"kb_loader": {"@misc": "myAdamKB.v1"}, "incl_context": False}, config={"incl_context": False},
init_config={"kb_loader": {"@misc": "myAdamKB.v1"}},
) )
nlp.initialize()
# With the default get_candidates function, matching is case-sensitive # With the default get_candidates function, matching is case-sensitive
text = "Douglas and douglas are not the same." text = "Douglas and douglas are not the same."
doc = nlp(text) doc = nlp(text)
@ -238,11 +240,14 @@ def test_el_pipe_configuration(nlp):
"entity_linker", "entity_linker",
"entity_linker", "entity_linker",
config={ config={
"kb_loader": {"@misc": "myAdamKB.v1"},
"incl_context": False, "incl_context": False,
"get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"}, "get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"},
}, },
init_config={
"kb_loader": {"@misc": "myAdamKB.v1"},
},
) )
nlp.initialize()
doc = nlp(text) doc = nlp(text)
assert doc[0].ent_kb_id_ == "Q2" assert doc[0].ent_kb_id_ == "Q2"
assert doc[1].ent_kb_id_ == "" assert doc[1].ent_kb_id_ == ""
@ -356,8 +361,9 @@ def test_preserving_links_asdoc(nlp):
] ]
ruler = nlp.add_pipe("entity_ruler") ruler = nlp.add_pipe("entity_ruler")
ruler.add_patterns(patterns) ruler.add_patterns(patterns)
el_config = {"kb_loader": {"@misc": "myLocationsKB.v1"}, "incl_prior": False} config = {"incl_prior": False}
entity_linker = nlp.add_pipe("entity_linker", config=el_config, last=True) init_config = {"kb_loader": {"@misc": "myLocationsKB.v1"}}
entity_linker = nlp.add_pipe("entity_linker", config=config, init_config=init_config, last=True)
nlp.initialize() nlp.initialize()
assert entity_linker.model.get_dim("nO") == vector_length assert entity_linker.model.get_dim("nO") == vector_length
@ -456,7 +462,7 @@ def test_overfitting_IO():
# Create the Entity Linker component and add it to the pipeline # Create the Entity Linker component and add it to the pipeline
entity_linker = nlp.add_pipe( entity_linker = nlp.add_pipe(
"entity_linker", "entity_linker",
config={"kb_loader": {"@misc": "myOverfittingKB.v1"}}, init_config={"kb_loader": {"@misc": "myOverfittingKB.v1"}},
last=True, last=True,
) )