From 4332d12ce28807f1102f5dc81a285c959ce72fad Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sat, 10 Jun 2023 16:55:52 +0200 Subject: [PATCH] Support adding pipeline component by instance --- spacy/__init__.py | 2 ++ spacy/tests/test_language.py | 33 +++++++++++++++++++++++++++++++++ spacy/util.py | 21 ++++++++++++++++++++- 3 files changed, 55 insertions(+), 1 deletion(-) diff --git a/spacy/__init__.py b/spacy/__init__.py index c3568bc5c..995f965ae 100644 --- a/spacy/__init__.py +++ b/spacy/__init__.py @@ -35,6 +35,7 @@ def load( enable: Union[str, Iterable[str]] = util._DEFAULT_EMPTY_PIPES, exclude: Union[str, Iterable[str]] = util._DEFAULT_EMPTY_PIPES, config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(), + pipe_instances: Dict[str, Any] = util.SimpleFrozenDict(), ) -> Language: """Load a spaCy model from an installed package or a local path. @@ -58,6 +59,7 @@ def load( enable=enable, exclude=exclude, config=config, + pipe_instances=pipe_instances, ) diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py index 236856dad..02e58d0a0 100644 --- a/spacy/tests/test_language.py +++ b/spacy/tests/test_language.py @@ -799,3 +799,36 @@ def test_component_return(): nlp.add_pipe("test_component_bad_pipe") with pytest.raises(ValueError, match="instead of a Doc"): nlp("text") + + +@pytest.mark.parametrize("components,kwargs,position", [ + (["t1", "t2"], {"before": "t1"}, 0), + (["t1", "t2"], {"after": "t1"}, 1), + (["t1", "t2"], {"after": "t1"}, 1), + (["t1", "t2"], {"first": True}, 0), + (["t1", "t2"], {"last": True}, 2), + (["t1", "t2"], {"last": False}, 2), + (["t1", "t2"], {"first": False}, ValueError), +]) +def test_add_pipe_instance(components, kwargs, position): + nlp = Language() + for name in components: + nlp.add_pipe("textcat", name=name) + pipe_names = list(nlp.pipe_names) + if isinstance(position, int): + result = nlp.add_pipe_instance(evil_component, name="new_component", **kwargs) + assert result is evil_component + pipe_names.insert(position, "new_component") + assert nlp.pipe_names == pipe_names + else: + with pytest.raises(ValueError): + result = nlp.add_pipe_instance(evil_component, name="new_component", **kwargs) + + +def test_add_pipe_instance_to_bytes(): + nlp = Language() + nlp.add_pipe("textcat", name="t1") + nlp.add_pipe("textcat", name="t2") + nlp.add_pipe_instance(evil_component, name="new_component") + b = nlp.to_bytes() + diff --git a/spacy/util.py b/spacy/util.py index 8cc89217d..fce3f73be 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -415,6 +415,7 @@ def load_model( enable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES, exclude: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES, config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), + pipe_instances: Dict[str, Any] = SimpleFrozenDict() ) -> "Language": """Load a model from a package or data path. @@ -426,6 +427,9 @@ def load_model( exclude (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to exclude. config (Dict[str, Any] / Config): Config overrides as nested dict or dict keyed by section values in dot notation. + pipe_instances (Dict[str, Any]): Dictionary of components + to be added to the pipeline directly (not created from + config) RETURNS (Language): The loaded nlp object. """ kwargs = { @@ -434,6 +438,7 @@ def load_model( "enable": enable, "exclude": exclude, "config": config, + "pipe_instances": pipe_instances } if isinstance(name, str): # name or string path if name.startswith("blank:"): # shortcut for blank model @@ -457,6 +462,7 @@ def load_model_from_package( enable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES, exclude: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES, config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), + pipe_instances: Dict[str, Any] = SimpleFrozenDict() ) -> "Language": """Load a model from an installed package. @@ -472,10 +478,13 @@ def load_model_from_package( components won't be loaded. config (Dict[str, Any] / Config): Config overrides as nested dict or dict keyed by section values in dot notation. + pipe_instances (Dict[str, Any]): Dictionary of components + to be added to the pipeline directly (not created from + config) RETURNS (Language): The loaded nlp object. """ cls = importlib.import_module(name) - return cls.load(vocab=vocab, disable=disable, enable=enable, exclude=exclude, config=config) # type: ignore[attr-defined] + return cls.load(vocab=vocab, disable=disable, enable=enable, exclude=exclude, config=config, pipe_instances=pipe_instances) # type: ignore[attr-defined] def load_model_from_path( @@ -487,6 +496,7 @@ def load_model_from_path( enable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES, exclude: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES, config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), + pipe_instances: Dict[str, Any] = SimpleFrozenDict() ) -> "Language": """Load a model from a data directory path. Creates Language class with pipeline from config.cfg and then calls from_disk() with path. @@ -504,6 +514,9 @@ def load_model_from_path( components won't be loaded. config (Dict[str, Any] / Config): Config overrides as nested dict or dict keyed by section values in dot notation. + pipe_instances (Dict[str, Any]): Dictionary of components + to be added to the pipeline directly (not created from + config) RETURNS (Language): The loaded nlp object. """ if not model_path.exists(): @@ -520,6 +533,7 @@ def load_model_from_path( enable=enable, exclude=exclude, meta=meta, + pipe_instances=pipe_instances ) return nlp.from_disk(model_path, exclude=exclude, overrides=overrides) @@ -534,6 +548,7 @@ def load_model_from_config( exclude: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES, auto_fill: bool = False, validate: bool = True, + pipe_instances: Dict[str, Any] = SimpleFrozenDict() ) -> "Language": """Create an nlp object from a config. Expects the full config file including a section "nlp" containing the settings for the nlp object. @@ -551,6 +566,9 @@ def load_model_from_config( components won't be loaded. auto_fill (bool): Whether to auto-fill config with missing defaults. validate (bool): Whether to show config validation errors. + pipe_instances (Dict[str, Any]): Dictionary of components + to be added to the pipeline directly (not created from + config) RETURNS (Language): The loaded nlp object. """ if "nlp" not in config: @@ -570,6 +588,7 @@ def load_model_from_config( auto_fill=auto_fill, validate=validate, meta=meta, + pipe_instances=pipe_instances ) return nlp