From aa0d747739bf361844d38311be6834b19d6bf5ca Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sat, 10 Jun 2023 16:55:13 +0200 Subject: [PATCH] Support adding pipeline component by instance --- spacy/language.py | 115 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 112 insertions(+), 3 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 9fdcf6328..7e52c60c3 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1,6 +1,6 @@ from typing import Iterator, Optional, Any, Dict, Callable, Iterable -from typing import Union, Tuple, List, Set, Pattern, Sequence -from typing import NoReturn, TYPE_CHECKING, TypeVar, cast, overload +from typing import Union, Tuple, List, Set, Pattern, Sequence, overload +from typing import NoReturn, TYPE_CHECKING, TypeVar, cast from dataclasses import dataclass import random @@ -52,6 +52,9 @@ DEFAULT_CONFIG = util.load_config(DEFAULT_CONFIG_PATH) # This is the base config for the [pretraining] block and currently not included # in the main config and only added via the 'init fill-config' command DEFAULT_CONFIG_PRETRAIN_PATH = Path(__file__).parent / "default_config_pretraining.cfg" +# Factory name indicating that the component wasn't constructed by a factory, +# and was instead passed by instance +INSTANCE_FACTORY_NAME = "__added_by_instance__" # Type variable for contexts piped with documents _AnyContext = TypeVar("_AnyContext") @@ -743,6 +746,9 @@ class Language: """Add a component to the processing pipeline. Valid components are callables that take a `Doc` object, modify it and return it. Only one of before/after/first/last can be set. Default behaviour is "last". + Components can be added either by factory name or by instance. If + an instance is supplied and you serialize the pipeline, you'll need + to also pass an instance into spacy.load() to construct the pipeline. factory_name (str): Name of the component factory. name (str): Name of pipeline component. Overwrites existing @@ -790,11 +796,58 @@ class Language: raw_config=raw_config, validate=validate, ) - pipe_index = self._get_pipe_index(before, after, first, last) self._pipe_meta[name] = self.get_factory_meta(factory_name) + pipe_index = self._get_pipe_index(before, after, first, last) self._components.insert(pipe_index, (name, pipe_component)) return pipe_component + def add_pipe_instance(self, component: PipeCallable, + /, name: Optional[str] = None, + *, + before: Optional[Union[str, int]] = None, + after: Optional[Union[str, int]] = None, + first: Optional[bool] = None, + last: Optional[bool] = None, + ) -> PipeCallable: + """Add a component instance to the processing pipeline. Valid components + are callables that take a `Doc` object, modify it and return it. Only one + of before/after/first/last can be set. Default behaviour is "last". + + A limitation of this method is that spaCy will not know how to reconstruct + your pipeline after you save it out (unlike the 'Language.add_pipe()' method, + where you provide a config and let spaCy construct the instance). See 'spacy.load' + for details of how to load back a pipeline with components added by instance. + + pipe_instance (Callable[[Doc], Doc]): The component to add. + name (str): Name of pipeline component. Overwrites existing + component.name attribute if available. If no name is set and + the component exposes no name attribute, component.__name__ is + used. An error is raised if a name already exists in the pipeline. + before (Union[str, int]): Name or index of the component to insert new + component directly before. + after (Union[str, int]): Name or index of the component to insert new + component directly after. + first (bool): If True, insert component first in the pipeline. + last (bool): If True, insert component last in the pipeline. + RETURNS (Callable[[Doc], Doc]): The pipeline component. + + DOCS: https://spacy.io/api/language#add_pipe_instance + """ + name = name if name is not None else getattr(component, "name") + if name is None: + raise ValueError("TODO error") + if name in self.component_names: + raise ValueError(Errors.E007.format(name=name, opts=self.component_names)) + + # It would be possible to take arguments for the FactoryMeta here, but we'll then have + # a problem on deserialization: where will the data be coming from? + # I think if someone wants that, they should register a component function. + self._pipe_meta[name] = FactoryMeta(INSTANCE_FACTORY_NAME) + self._pipe_configs[name] = Config() + pipe_index = self._get_pipe_index(before, after, first, last) + self._components.insert(pipe_index, (name, component)) + return component + def _get_pipe_index( self, before: Optional[Union[str, int]] = None, @@ -1690,6 +1743,7 @@ class Language: meta: Dict[str, Any] = SimpleFrozenDict(), auto_fill: bool = True, validate: bool = True, + pipe_instances: Dict[str, Any] = SimpleFrozenDict() ) -> "Language": """Create the nlp object from a loaded config. Will set up the tokenizer and language data, add pipeline components etc. If no config is provided, @@ -1765,6 +1819,11 @@ class Language: # Warn about require_gpu usage in jupyter notebook warn_if_jupyter_cupy() + # If we've been passed pipe instances, check whether + # they have a Vocab instance, and if they do, use + # that one. This also performs some additional checks and + # warns if there's a mismatch. + vocab = _get_instantiated_vocab(vocab, pipe_instances) # Note that we don't load vectors here, instead they get loaded explicitly # inside stuff like the spacy train function. If we loaded them here, @@ -1781,6 +1840,11 @@ class Language: interpolated = filled.interpolate() if not filled.is_interpolated else filled pipeline = interpolated.get("components", {}) sourced = util.get_sourced_components(interpolated) + # Check for components that aren't in the pipe_instances dict, aren't disabled, + # and aren't built by factory. + missing_components = _find_missing_components(pipeline, pipe_instances, exclude) + if missing_components: + raise ValueError(Errors.E1052.format(", ",join(missing_components))) # If components are loaded from a source (existing models), we cache # them here so they're only loaded once source_nlps = {} @@ -1790,6 +1854,18 @@ class Language: if pipe_name not in pipeline: opts = ", ".join(pipeline.keys()) raise ValueError(Errors.E956.format(name=pipe_name, opts=opts)) + if pipe_name in pipe_instances: + if pipe_name in exclude: + continue + else: + nlp.add_pipe_instance( + pipe_instances[pipe_name] + ) + # Is it important that we instantiate pipes that + # aren't excluded? It seems like we would want + # the exclude check above. I've left it how it + # is though, in case there's some sort of crazy + # load-bearing side-effects someone is relying on? pipe_cfg = util.copy_config(pipeline[pipe_name]) raw_config = Config(filled["components"][pipe_name]) if pipe_name not in exclude: @@ -2306,3 +2382,36 @@ class _Sender: if self.count >= self.chunk_size: self.count = 0 self.send() + + +def _get_instantiated_vocab(vocab: Union[bool, Vocab], pipe_instances: Dict[str, Any]) -> Union[bool, Vocab]: + vocab_instances = {} + for name, instance in pipe_instances.items(): + if hasattr(instance, "vocab") and isinstance(instance.vocab, Vocab): + vocab_instances[name] = instance.vocab + if not vocab_instances: + return vocab + elif isinstance(vocab, Vocab): + for name, inst_voc in vocab_instances.items(): + if inst_voc is not vocab: + warnings.warn(Warnings.W125.format(name=name)) + return vocab + else: + resolved_vocab = None + for name, inst_voc in vocab_instances.items(): + if resolved_vocab is None: + resolved_vocab = inst_voc + elif inst_voc is not resolved_vocab: + warnings.warn(Warnings.W125.format(name=name)) + # This is supposed to only be for the type checker -- + # it should be unreachable + assert resolved_vocab is not None + return resolved_vocab + + +def _find_missing_components( + pipeline: List[str], + pipe_instances: Dict[str, Any], + exclude: List[str] +) -> List[str]: + return [name for name in pipeline if name not in pipe_instances and name not in exclude]