Support adding pipeline component by instance

This commit is contained in:
Matthew Honnibal 2023-06-10 16:55:13 +02:00
parent 6f821efaf3
commit aa0d747739

View File

@ -1,6 +1,6 @@
from typing import Iterator, Optional, Any, Dict, Callable, Iterable from typing import Iterator, Optional, Any, Dict, Callable, Iterable
from typing import Union, Tuple, List, Set, Pattern, Sequence from typing import Union, Tuple, List, Set, Pattern, Sequence, overload
from typing import NoReturn, TYPE_CHECKING, TypeVar, cast, overload from typing import NoReturn, TYPE_CHECKING, TypeVar, cast
from dataclasses import dataclass from dataclasses import dataclass
import random import random
@ -52,6 +52,9 @@ DEFAULT_CONFIG = util.load_config(DEFAULT_CONFIG_PATH)
# This is the base config for the [pretraining] block and currently not included # This is the base config for the [pretraining] block and currently not included
# in the main config and only added via the 'init fill-config' command # in the main config and only added via the 'init fill-config' command
DEFAULT_CONFIG_PRETRAIN_PATH = Path(__file__).parent / "default_config_pretraining.cfg" DEFAULT_CONFIG_PRETRAIN_PATH = Path(__file__).parent / "default_config_pretraining.cfg"
# Factory name indicating that the component wasn't constructed by a factory,
# and was instead passed by instance
INSTANCE_FACTORY_NAME = "__added_by_instance__"
# Type variable for contexts piped with documents # Type variable for contexts piped with documents
_AnyContext = TypeVar("_AnyContext") _AnyContext = TypeVar("_AnyContext")
@ -743,6 +746,9 @@ class Language:
"""Add a component to the processing pipeline. Valid components are """Add a component to the processing pipeline. Valid components are
callables that take a `Doc` object, modify it and return it. Only one callables that take a `Doc` object, modify it and return it. Only one
of before/after/first/last can be set. Default behaviour is "last". of before/after/first/last can be set. Default behaviour is "last".
Components can be added either by factory name or by instance. If
an instance is supplied and you serialize the pipeline, you'll need
to also pass an instance into spacy.load() to construct the pipeline.
factory_name (str): Name of the component factory. factory_name (str): Name of the component factory.
name (str): Name of pipeline component. Overwrites existing name (str): Name of pipeline component. Overwrites existing
@ -790,11 +796,58 @@ class Language:
raw_config=raw_config, raw_config=raw_config,
validate=validate, validate=validate,
) )
pipe_index = self._get_pipe_index(before, after, first, last)
self._pipe_meta[name] = self.get_factory_meta(factory_name) self._pipe_meta[name] = self.get_factory_meta(factory_name)
pipe_index = self._get_pipe_index(before, after, first, last)
self._components.insert(pipe_index, (name, pipe_component)) self._components.insert(pipe_index, (name, pipe_component))
return pipe_component return pipe_component
def add_pipe_instance(self, component: PipeCallable,
/, name: Optional[str] = None,
*,
before: Optional[Union[str, int]] = None,
after: Optional[Union[str, int]] = None,
first: Optional[bool] = None,
last: Optional[bool] = None,
) -> PipeCallable:
"""Add a component instance to the processing pipeline. Valid components
are callables that take a `Doc` object, modify it and return it. Only one
of before/after/first/last can be set. Default behaviour is "last".
A limitation of this method is that spaCy will not know how to reconstruct
your pipeline after you save it out (unlike the 'Language.add_pipe()' method,
where you provide a config and let spaCy construct the instance). See 'spacy.load'
for details of how to load back a pipeline with components added by instance.
pipe_instance (Callable[[Doc], Doc]): The component to add.
name (str): Name of pipeline component. Overwrites existing
component.name attribute if available. If no name is set and
the component exposes no name attribute, component.__name__ is
used. An error is raised if a name already exists in the pipeline.
before (Union[str, int]): Name or index of the component to insert new
component directly before.
after (Union[str, int]): Name or index of the component to insert new
component directly after.
first (bool): If True, insert component first in the pipeline.
last (bool): If True, insert component last in the pipeline.
RETURNS (Callable[[Doc], Doc]): The pipeline component.
DOCS: https://spacy.io/api/language#add_pipe_instance
"""
name = name if name is not None else getattr(component, "name")
if name is None:
raise ValueError("TODO error")
if name in self.component_names:
raise ValueError(Errors.E007.format(name=name, opts=self.component_names))
# It would be possible to take arguments for the FactoryMeta here, but we'll then have
# a problem on deserialization: where will the data be coming from?
# I think if someone wants that, they should register a component function.
self._pipe_meta[name] = FactoryMeta(INSTANCE_FACTORY_NAME)
self._pipe_configs[name] = Config()
pipe_index = self._get_pipe_index(before, after, first, last)
self._components.insert(pipe_index, (name, component))
return component
def _get_pipe_index( def _get_pipe_index(
self, self,
before: Optional[Union[str, int]] = None, before: Optional[Union[str, int]] = None,
@ -1690,6 +1743,7 @@ class Language:
meta: Dict[str, Any] = SimpleFrozenDict(), meta: Dict[str, Any] = SimpleFrozenDict(),
auto_fill: bool = True, auto_fill: bool = True,
validate: bool = True, validate: bool = True,
pipe_instances: Dict[str, Any] = SimpleFrozenDict()
) -> "Language": ) -> "Language":
"""Create the nlp object from a loaded config. Will set up the tokenizer """Create the nlp object from a loaded config. Will set up the tokenizer
and language data, add pipeline components etc. If no config is provided, and language data, add pipeline components etc. If no config is provided,
@ -1765,6 +1819,11 @@ class Language:
# Warn about require_gpu usage in jupyter notebook # Warn about require_gpu usage in jupyter notebook
warn_if_jupyter_cupy() warn_if_jupyter_cupy()
# If we've been passed pipe instances, check whether
# they have a Vocab instance, and if they do, use
# that one. This also performs some additional checks and
# warns if there's a mismatch.
vocab = _get_instantiated_vocab(vocab, pipe_instances)
# Note that we don't load vectors here, instead they get loaded explicitly # Note that we don't load vectors here, instead they get loaded explicitly
# inside stuff like the spacy train function. If we loaded them here, # inside stuff like the spacy train function. If we loaded them here,
@ -1781,6 +1840,11 @@ class Language:
interpolated = filled.interpolate() if not filled.is_interpolated else filled interpolated = filled.interpolate() if not filled.is_interpolated else filled
pipeline = interpolated.get("components", {}) pipeline = interpolated.get("components", {})
sourced = util.get_sourced_components(interpolated) sourced = util.get_sourced_components(interpolated)
# Check for components that aren't in the pipe_instances dict, aren't disabled,
# and aren't built by factory.
missing_components = _find_missing_components(pipeline, pipe_instances, exclude)
if missing_components:
raise ValueError(Errors.E1052.format(", ",join(missing_components)))
# If components are loaded from a source (existing models), we cache # If components are loaded from a source (existing models), we cache
# them here so they're only loaded once # them here so they're only loaded once
source_nlps = {} source_nlps = {}
@ -1790,6 +1854,18 @@ class Language:
if pipe_name not in pipeline: if pipe_name not in pipeline:
opts = ", ".join(pipeline.keys()) opts = ", ".join(pipeline.keys())
raise ValueError(Errors.E956.format(name=pipe_name, opts=opts)) raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
if pipe_name in pipe_instances:
if pipe_name in exclude:
continue
else:
nlp.add_pipe_instance(
pipe_instances[pipe_name]
)
# Is it important that we instantiate pipes that
# aren't excluded? It seems like we would want
# the exclude check above. I've left it how it
# is though, in case there's some sort of crazy
# load-bearing side-effects someone is relying on?
pipe_cfg = util.copy_config(pipeline[pipe_name]) pipe_cfg = util.copy_config(pipeline[pipe_name])
raw_config = Config(filled["components"][pipe_name]) raw_config = Config(filled["components"][pipe_name])
if pipe_name not in exclude: if pipe_name not in exclude:
@ -2306,3 +2382,36 @@ class _Sender:
if self.count >= self.chunk_size: if self.count >= self.chunk_size:
self.count = 0 self.count = 0
self.send() self.send()
def _get_instantiated_vocab(vocab: Union[bool, Vocab], pipe_instances: Dict[str, Any]) -> Union[bool, Vocab]:
vocab_instances = {}
for name, instance in pipe_instances.items():
if hasattr(instance, "vocab") and isinstance(instance.vocab, Vocab):
vocab_instances[name] = instance.vocab
if not vocab_instances:
return vocab
elif isinstance(vocab, Vocab):
for name, inst_voc in vocab_instances.items():
if inst_voc is not vocab:
warnings.warn(Warnings.W125.format(name=name))
return vocab
else:
resolved_vocab = None
for name, inst_voc in vocab_instances.items():
if resolved_vocab is None:
resolved_vocab = inst_voc
elif inst_voc is not resolved_vocab:
warnings.warn(Warnings.W125.format(name=name))
# This is supposed to only be for the type checker --
# it should be unreachable
assert resolved_vocab is not None
return resolved_vocab
def _find_missing_components(
pipeline: List[str],
pipe_instances: Dict[str, Any],
exclude: List[str]
) -> List[str]:
return [name for name in pipeline if name not in pipe_instances and name not in exclude]