mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-26 16:42:55 +03:00
add init_config to nlp.create_pipe
This commit is contained in:
parent
33c2d4af16
commit
6b8bdb2d39
|
@ -600,6 +600,7 @@ class Language:
|
||||||
*,
|
*,
|
||||||
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
||||||
raw_config: Optional[Config] = None,
|
raw_config: Optional[Config] = None,
|
||||||
|
init_config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
||||||
validate: bool = True,
|
validate: bool = True,
|
||||||
) -> Callable[[Doc], Doc]:
|
) -> Callable[[Doc], Doc]:
|
||||||
"""Create a pipeline component. Mostly used internally. To create and
|
"""Create a pipeline component. Mostly used internally. To create and
|
||||||
|
@ -611,6 +612,9 @@ class Language:
|
||||||
config (Optional[Dict[str, Any]]): Config parameters to use for this
|
config (Optional[Dict[str, Any]]): Config parameters to use for this
|
||||||
component. Will be merged with default config, if available.
|
component. Will be merged with default config, if available.
|
||||||
raw_config (Optional[Config]): Internals: the non-interpolated config.
|
raw_config (Optional[Config]): Internals: the non-interpolated config.
|
||||||
|
init_config (Optional[Dict[str, Any]]): Config parameters to use to
|
||||||
|
initialize this component. Will be used to update the internal
|
||||||
|
'initialize' config.
|
||||||
validate (bool): Whether to validate the component config against the
|
validate (bool): Whether to validate the component config against the
|
||||||
arguments and types expected by the factory.
|
arguments and types expected by the factory.
|
||||||
RETURNS (Callable[[Doc], Doc]): The pipeline component.
|
RETURNS (Callable[[Doc], Doc]): The pipeline component.
|
||||||
|
@ -621,8 +625,13 @@ class Language:
|
||||||
if not isinstance(config, dict):
|
if not isinstance(config, dict):
|
||||||
err = Errors.E962.format(style="config", name=name, cfg_type=type(config))
|
err = Errors.E962.format(style="config", name=name, cfg_type=type(config))
|
||||||
raise ValueError(err)
|
raise ValueError(err)
|
||||||
|
if not isinstance(init_config, dict):
|
||||||
|
err = Errors.E962.format(style="init_config", name=name, cfg_type=type(init_config))
|
||||||
|
raise ValueError(err)
|
||||||
if not srsly.is_json_serializable(config):
|
if not srsly.is_json_serializable(config):
|
||||||
raise ValueError(Errors.E961.format(config=config))
|
raise ValueError(Errors.E961.format(config=config))
|
||||||
|
if not srsly.is_json_serializable(init_config):
|
||||||
|
raise ValueError(Errors.E961.format(config=init_config))
|
||||||
if not self.has_factory(factory_name):
|
if not self.has_factory(factory_name):
|
||||||
err = Errors.E002.format(
|
err = Errors.E002.format(
|
||||||
name=factory_name,
|
name=factory_name,
|
||||||
|
@ -634,6 +643,8 @@ class Language:
|
||||||
raise ValueError(err)
|
raise ValueError(err)
|
||||||
pipe_meta = self.get_factory_meta(factory_name)
|
pipe_meta = self.get_factory_meta(factory_name)
|
||||||
config = config or {}
|
config = config or {}
|
||||||
|
if init_config:
|
||||||
|
self._config["initialize"]["components"][name] = init_config
|
||||||
# This is unideal, but the alternative would mean you always need to
|
# This is unideal, but the alternative would mean you always need to
|
||||||
# specify the full config settings, which is not really viable.
|
# specify the full config settings, which is not really viable.
|
||||||
if pipe_meta.default_config:
|
if pipe_meta.default_config:
|
||||||
|
@ -708,6 +719,7 @@ class Language:
|
||||||
source: Optional["Language"] = None,
|
source: Optional["Language"] = None,
|
||||||
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
||||||
raw_config: Optional[Config] = None,
|
raw_config: Optional[Config] = None,
|
||||||
|
init_config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
||||||
validate: bool = True,
|
validate: bool = True,
|
||||||
) -> Callable[[Doc], Doc]:
|
) -> Callable[[Doc], Doc]:
|
||||||
"""Add a component to the processing pipeline. Valid components are
|
"""Add a component to the processing pipeline. Valid components are
|
||||||
|
@ -730,6 +742,9 @@ class Language:
|
||||||
config (Optional[Dict[str, Any]]): Config parameters to use for this
|
config (Optional[Dict[str, Any]]): Config parameters to use for this
|
||||||
component. Will be merged with default config, if available.
|
component. Will be merged with default config, if available.
|
||||||
raw_config (Optional[Config]): Internals: the non-interpolated config.
|
raw_config (Optional[Config]): Internals: the non-interpolated config.
|
||||||
|
init_config (Optional[Dict[str, Any]]): Config parameters to use to
|
||||||
|
initialize this component. Will be used to update the internal
|
||||||
|
'initialize' config.
|
||||||
validate (bool): Whether to validate the component config against the
|
validate (bool): Whether to validate the component config against the
|
||||||
arguments and types expected by the factory.
|
arguments and types expected by the factory.
|
||||||
RETURNS (Callable[[Doc], Doc]): The pipeline component.
|
RETURNS (Callable[[Doc], Doc]): The pipeline component.
|
||||||
|
@ -763,6 +778,7 @@ class Language:
|
||||||
name=name,
|
name=name,
|
||||||
config=config,
|
config=config,
|
||||||
raw_config=raw_config,
|
raw_config=raw_config,
|
||||||
|
init_config=init_config,
|
||||||
validate=validate,
|
validate=validate,
|
||||||
)
|
)
|
||||||
pipe_index = self._get_pipe_index(before, after, first, last)
|
pipe_index = self._get_pipe_index(before, after, first, last)
|
||||||
|
@ -842,6 +858,7 @@ class Language:
|
||||||
factory_name: str,
|
factory_name: str,
|
||||||
*,
|
*,
|
||||||
config: Dict[str, Any] = SimpleFrozenDict(),
|
config: Dict[str, Any] = SimpleFrozenDict(),
|
||||||
|
init_config: Dict[str, Any] = SimpleFrozenDict(),
|
||||||
validate: bool = True,
|
validate: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Replace a component in the pipeline.
|
"""Replace a component in the pipeline.
|
||||||
|
@ -850,6 +867,9 @@ class Language:
|
||||||
factory_name (str): Factory name of replacement component.
|
factory_name (str): Factory name of replacement component.
|
||||||
config (Optional[Dict[str, Any]]): Config parameters to use for this
|
config (Optional[Dict[str, Any]]): Config parameters to use for this
|
||||||
component. Will be merged with default config, if available.
|
component. Will be merged with default config, if available.
|
||||||
|
init_config (Optional[Dict[str, Any]]): Config parameters to use to
|
||||||
|
initialize this component. Will be used to update the internal
|
||||||
|
'initialize' config.
|
||||||
validate (bool): Whether to validate the component config against the
|
validate (bool): Whether to validate the component config against the
|
||||||
arguments and types expected by the factory.
|
arguments and types expected by the factory.
|
||||||
|
|
||||||
|
@ -866,13 +886,14 @@ class Language:
|
||||||
self.remove_pipe(name)
|
self.remove_pipe(name)
|
||||||
if not len(self._components) or pipe_index == len(self._components):
|
if not len(self._components) or pipe_index == len(self._components):
|
||||||
# we have no components to insert before/after, or we're replacing the last component
|
# we have no components to insert before/after, or we're replacing the last component
|
||||||
self.add_pipe(factory_name, name=name, config=config, validate=validate)
|
self.add_pipe(factory_name, name=name, config=config, init_config=init_config, validate=validate)
|
||||||
else:
|
else:
|
||||||
self.add_pipe(
|
self.add_pipe(
|
||||||
factory_name,
|
factory_name,
|
||||||
name=name,
|
name=name,
|
||||||
before=pipe_index,
|
before=pipe_index,
|
||||||
config=config,
|
config=config,
|
||||||
|
init_config=init_config,
|
||||||
validate=validate,
|
validate=validate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -110,7 +110,7 @@ def test_kb_invalid_entity_vector(nlp):
|
||||||
|
|
||||||
|
|
||||||
def test_kb_default(nlp):
|
def test_kb_default(nlp):
|
||||||
"""Test that the default (empty) KB is loaded when not providing a config"""
|
"""Test that the default (empty) KB is loaded upon construction"""
|
||||||
entity_linker = nlp.add_pipe("entity_linker", config={})
|
entity_linker = nlp.add_pipe("entity_linker", config={})
|
||||||
assert len(entity_linker.kb) == 0
|
assert len(entity_linker.kb) == 0
|
||||||
assert entity_linker.kb.get_size_entities() == 0
|
assert entity_linker.kb.get_size_entities() == 0
|
||||||
|
@ -122,7 +122,7 @@ def test_kb_default(nlp):
|
||||||
def test_kb_custom_length(nlp):
|
def test_kb_custom_length(nlp):
|
||||||
"""Test that the default (empty) KB can be configured with a custom entity length"""
|
"""Test that the default (empty) KB can be configured with a custom entity length"""
|
||||||
entity_linker = nlp.add_pipe(
|
entity_linker = nlp.add_pipe(
|
||||||
"entity_linker", config={"kb_loader": {"entity_vector_length": 35}}
|
"entity_linker", config={"entity_vector_length": 35}
|
||||||
)
|
)
|
||||||
assert len(entity_linker.kb) == 0
|
assert len(entity_linker.kb) == 0
|
||||||
assert entity_linker.kb.get_size_entities() == 0
|
assert entity_linker.kb.get_size_entities() == 0
|
||||||
|
@ -140,7 +140,7 @@ def test_kb_undefined(nlp):
|
||||||
def test_kb_empty(nlp):
|
def test_kb_empty(nlp):
|
||||||
"""Test that the EL can't train with an empty KB"""
|
"""Test that the EL can't train with an empty KB"""
|
||||||
config = {"kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
|
config = {"kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
|
||||||
entity_linker = nlp.add_pipe("entity_linker", config=config)
|
entity_linker = nlp.add_pipe("entity_linker", init_config=config)
|
||||||
assert len(entity_linker.kb) == 0
|
assert len(entity_linker.kb) == 0
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
entity_linker.initialize(lambda: [])
|
entity_linker.initialize(lambda: [])
|
||||||
|
@ -217,8 +217,10 @@ def test_el_pipe_configuration(nlp):
|
||||||
# run an EL pipe without a trained context encoder, to check the candidate generation step only
|
# run an EL pipe without a trained context encoder, to check the candidate generation step only
|
||||||
nlp.add_pipe(
|
nlp.add_pipe(
|
||||||
"entity_linker",
|
"entity_linker",
|
||||||
config={"kb_loader": {"@misc": "myAdamKB.v1"}, "incl_context": False},
|
config={"incl_context": False},
|
||||||
|
init_config={"kb_loader": {"@misc": "myAdamKB.v1"}},
|
||||||
)
|
)
|
||||||
|
nlp.initialize()
|
||||||
# With the default get_candidates function, matching is case-sensitive
|
# With the default get_candidates function, matching is case-sensitive
|
||||||
text = "Douglas and douglas are not the same."
|
text = "Douglas and douglas are not the same."
|
||||||
doc = nlp(text)
|
doc = nlp(text)
|
||||||
|
@ -238,11 +240,14 @@ def test_el_pipe_configuration(nlp):
|
||||||
"entity_linker",
|
"entity_linker",
|
||||||
"entity_linker",
|
"entity_linker",
|
||||||
config={
|
config={
|
||||||
"kb_loader": {"@misc": "myAdamKB.v1"},
|
|
||||||
"incl_context": False,
|
"incl_context": False,
|
||||||
"get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"},
|
"get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"},
|
||||||
},
|
},
|
||||||
|
init_config={
|
||||||
|
"kb_loader": {"@misc": "myAdamKB.v1"},
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
nlp.initialize()
|
||||||
doc = nlp(text)
|
doc = nlp(text)
|
||||||
assert doc[0].ent_kb_id_ == "Q2"
|
assert doc[0].ent_kb_id_ == "Q2"
|
||||||
assert doc[1].ent_kb_id_ == ""
|
assert doc[1].ent_kb_id_ == ""
|
||||||
|
@ -356,8 +361,9 @@ def test_preserving_links_asdoc(nlp):
|
||||||
]
|
]
|
||||||
ruler = nlp.add_pipe("entity_ruler")
|
ruler = nlp.add_pipe("entity_ruler")
|
||||||
ruler.add_patterns(patterns)
|
ruler.add_patterns(patterns)
|
||||||
el_config = {"kb_loader": {"@misc": "myLocationsKB.v1"}, "incl_prior": False}
|
config = {"incl_prior": False}
|
||||||
entity_linker = nlp.add_pipe("entity_linker", config=el_config, last=True)
|
init_config = {"kb_loader": {"@misc": "myLocationsKB.v1"}}
|
||||||
|
entity_linker = nlp.add_pipe("entity_linker", config=config, init_config=init_config, last=True)
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
assert entity_linker.model.get_dim("nO") == vector_length
|
assert entity_linker.model.get_dim("nO") == vector_length
|
||||||
|
|
||||||
|
@ -456,7 +462,7 @@ def test_overfitting_IO():
|
||||||
# Create the Entity Linker component and add it to the pipeline
|
# Create the Entity Linker component and add it to the pipeline
|
||||||
entity_linker = nlp.add_pipe(
|
entity_linker = nlp.add_pipe(
|
||||||
"entity_linker",
|
"entity_linker",
|
||||||
config={"kb_loader": {"@misc": "myOverfittingKB.v1"}},
|
init_config={"kb_loader": {"@misc": "myOverfittingKB.v1"}},
|
||||||
last=True,
|
last=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user