mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
Support adding pipeline component by instance
This commit is contained in:
parent
6f821efaf3
commit
aa0d747739
|
@ -1,6 +1,6 @@
|
||||||
from typing import Iterator, Optional, Any, Dict, Callable, Iterable
|
from typing import Iterator, Optional, Any, Dict, Callable, Iterable
|
||||||
from typing import Union, Tuple, List, Set, Pattern, Sequence
|
from typing import Union, Tuple, List, Set, Pattern, Sequence, overload
|
||||||
from typing import NoReturn, TYPE_CHECKING, TypeVar, cast, overload
|
from typing import NoReturn, TYPE_CHECKING, TypeVar, cast
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import random
|
import random
|
||||||
|
@ -52,6 +52,9 @@ DEFAULT_CONFIG = util.load_config(DEFAULT_CONFIG_PATH)
|
||||||
# This is the base config for the [pretraining] block and currently not included
|
# This is the base config for the [pretraining] block and currently not included
|
||||||
# in the main config and only added via the 'init fill-config' command
|
# in the main config and only added via the 'init fill-config' command
|
||||||
DEFAULT_CONFIG_PRETRAIN_PATH = Path(__file__).parent / "default_config_pretraining.cfg"
|
DEFAULT_CONFIG_PRETRAIN_PATH = Path(__file__).parent / "default_config_pretraining.cfg"
|
||||||
|
# Factory name indicating that the component wasn't constructed by a factory,
|
||||||
|
# and was instead passed by instance
|
||||||
|
INSTANCE_FACTORY_NAME = "__added_by_instance__"
|
||||||
|
|
||||||
# Type variable for contexts piped with documents
|
# Type variable for contexts piped with documents
|
||||||
_AnyContext = TypeVar("_AnyContext")
|
_AnyContext = TypeVar("_AnyContext")
|
||||||
|
@ -743,6 +746,9 @@ class Language:
|
||||||
"""Add a component to the processing pipeline. Valid components are
|
"""Add a component to the processing pipeline. Valid components are
|
||||||
callables that take a `Doc` object, modify it and return it. Only one
|
callables that take a `Doc` object, modify it and return it. Only one
|
||||||
of before/after/first/last can be set. Default behaviour is "last".
|
of before/after/first/last can be set. Default behaviour is "last".
|
||||||
|
Components can be added either by factory name or by instance. If
|
||||||
|
an instance is supplied and you serialize the pipeline, you'll need
|
||||||
|
to also pass an instance into spacy.load() to construct the pipeline.
|
||||||
|
|
||||||
factory_name (str): Name of the component factory.
|
factory_name (str): Name of the component factory.
|
||||||
name (str): Name of pipeline component. Overwrites existing
|
name (str): Name of pipeline component. Overwrites existing
|
||||||
|
@ -790,11 +796,58 @@ class Language:
|
||||||
raw_config=raw_config,
|
raw_config=raw_config,
|
||||||
validate=validate,
|
validate=validate,
|
||||||
)
|
)
|
||||||
pipe_index = self._get_pipe_index(before, after, first, last)
|
|
||||||
self._pipe_meta[name] = self.get_factory_meta(factory_name)
|
self._pipe_meta[name] = self.get_factory_meta(factory_name)
|
||||||
|
pipe_index = self._get_pipe_index(before, after, first, last)
|
||||||
self._components.insert(pipe_index, (name, pipe_component))
|
self._components.insert(pipe_index, (name, pipe_component))
|
||||||
return pipe_component
|
return pipe_component
|
||||||
|
|
||||||
|
def add_pipe_instance(self, component: PipeCallable,
|
||||||
|
/, name: Optional[str] = None,
|
||||||
|
*,
|
||||||
|
before: Optional[Union[str, int]] = None,
|
||||||
|
after: Optional[Union[str, int]] = None,
|
||||||
|
first: Optional[bool] = None,
|
||||||
|
last: Optional[bool] = None,
|
||||||
|
) -> PipeCallable:
|
||||||
|
"""Add a component instance to the processing pipeline. Valid components
|
||||||
|
are callables that take a `Doc` object, modify it and return it. Only one
|
||||||
|
of before/after/first/last can be set. Default behaviour is "last".
|
||||||
|
|
||||||
|
A limitation of this method is that spaCy will not know how to reconstruct
|
||||||
|
your pipeline after you save it out (unlike the 'Language.add_pipe()' method,
|
||||||
|
where you provide a config and let spaCy construct the instance). See 'spacy.load'
|
||||||
|
for details of how to load back a pipeline with components added by instance.
|
||||||
|
|
||||||
|
pipe_instance (Callable[[Doc], Doc]): The component to add.
|
||||||
|
name (str): Name of pipeline component. Overwrites existing
|
||||||
|
component.name attribute if available. If no name is set and
|
||||||
|
the component exposes no name attribute, component.__name__ is
|
||||||
|
used. An error is raised if a name already exists in the pipeline.
|
||||||
|
before (Union[str, int]): Name or index of the component to insert new
|
||||||
|
component directly before.
|
||||||
|
after (Union[str, int]): Name or index of the component to insert new
|
||||||
|
component directly after.
|
||||||
|
first (bool): If True, insert component first in the pipeline.
|
||||||
|
last (bool): If True, insert component last in the pipeline.
|
||||||
|
RETURNS (Callable[[Doc], Doc]): The pipeline component.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/language#add_pipe_instance
|
||||||
|
"""
|
||||||
|
name = name if name is not None else getattr(component, "name")
|
||||||
|
if name is None:
|
||||||
|
raise ValueError("TODO error")
|
||||||
|
if name in self.component_names:
|
||||||
|
raise ValueError(Errors.E007.format(name=name, opts=self.component_names))
|
||||||
|
|
||||||
|
# It would be possible to take arguments for the FactoryMeta here, but we'll then have
|
||||||
|
# a problem on deserialization: where will the data be coming from?
|
||||||
|
# I think if someone wants that, they should register a component function.
|
||||||
|
self._pipe_meta[name] = FactoryMeta(INSTANCE_FACTORY_NAME)
|
||||||
|
self._pipe_configs[name] = Config()
|
||||||
|
pipe_index = self._get_pipe_index(before, after, first, last)
|
||||||
|
self._components.insert(pipe_index, (name, component))
|
||||||
|
return component
|
||||||
|
|
||||||
def _get_pipe_index(
|
def _get_pipe_index(
|
||||||
self,
|
self,
|
||||||
before: Optional[Union[str, int]] = None,
|
before: Optional[Union[str, int]] = None,
|
||||||
|
@ -1690,6 +1743,7 @@ class Language:
|
||||||
meta: Dict[str, Any] = SimpleFrozenDict(),
|
meta: Dict[str, Any] = SimpleFrozenDict(),
|
||||||
auto_fill: bool = True,
|
auto_fill: bool = True,
|
||||||
validate: bool = True,
|
validate: bool = True,
|
||||||
|
pipe_instances: Dict[str, Any] = SimpleFrozenDict()
|
||||||
) -> "Language":
|
) -> "Language":
|
||||||
"""Create the nlp object from a loaded config. Will set up the tokenizer
|
"""Create the nlp object from a loaded config. Will set up the tokenizer
|
||||||
and language data, add pipeline components etc. If no config is provided,
|
and language data, add pipeline components etc. If no config is provided,
|
||||||
|
@ -1765,6 +1819,11 @@ class Language:
|
||||||
|
|
||||||
# Warn about require_gpu usage in jupyter notebook
|
# Warn about require_gpu usage in jupyter notebook
|
||||||
warn_if_jupyter_cupy()
|
warn_if_jupyter_cupy()
|
||||||
|
# If we've been passed pipe instances, check whether
|
||||||
|
# they have a Vocab instance, and if they do, use
|
||||||
|
# that one. This also performs some additional checks and
|
||||||
|
# warns if there's a mismatch.
|
||||||
|
vocab = _get_instantiated_vocab(vocab, pipe_instances)
|
||||||
|
|
||||||
# Note that we don't load vectors here, instead they get loaded explicitly
|
# Note that we don't load vectors here, instead they get loaded explicitly
|
||||||
# inside stuff like the spacy train function. If we loaded them here,
|
# inside stuff like the spacy train function. If we loaded them here,
|
||||||
|
@ -1781,6 +1840,11 @@ class Language:
|
||||||
interpolated = filled.interpolate() if not filled.is_interpolated else filled
|
interpolated = filled.interpolate() if not filled.is_interpolated else filled
|
||||||
pipeline = interpolated.get("components", {})
|
pipeline = interpolated.get("components", {})
|
||||||
sourced = util.get_sourced_components(interpolated)
|
sourced = util.get_sourced_components(interpolated)
|
||||||
|
# Check for components that aren't in the pipe_instances dict, aren't disabled,
|
||||||
|
# and aren't built by factory.
|
||||||
|
missing_components = _find_missing_components(pipeline, pipe_instances, exclude)
|
||||||
|
if missing_components:
|
||||||
|
raise ValueError(Errors.E1052.format(", ",join(missing_components)))
|
||||||
# If components are loaded from a source (existing models), we cache
|
# If components are loaded from a source (existing models), we cache
|
||||||
# them here so they're only loaded once
|
# them here so they're only loaded once
|
||||||
source_nlps = {}
|
source_nlps = {}
|
||||||
|
@ -1790,6 +1854,18 @@ class Language:
|
||||||
if pipe_name not in pipeline:
|
if pipe_name not in pipeline:
|
||||||
opts = ", ".join(pipeline.keys())
|
opts = ", ".join(pipeline.keys())
|
||||||
raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
|
raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
|
||||||
|
if pipe_name in pipe_instances:
|
||||||
|
if pipe_name in exclude:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
nlp.add_pipe_instance(
|
||||||
|
pipe_instances[pipe_name]
|
||||||
|
)
|
||||||
|
# Is it important that we instantiate pipes that
|
||||||
|
# aren't excluded? It seems like we would want
|
||||||
|
# the exclude check above. I've left it how it
|
||||||
|
# is though, in case there's some sort of crazy
|
||||||
|
# load-bearing side-effects someone is relying on?
|
||||||
pipe_cfg = util.copy_config(pipeline[pipe_name])
|
pipe_cfg = util.copy_config(pipeline[pipe_name])
|
||||||
raw_config = Config(filled["components"][pipe_name])
|
raw_config = Config(filled["components"][pipe_name])
|
||||||
if pipe_name not in exclude:
|
if pipe_name not in exclude:
|
||||||
|
@ -2306,3 +2382,36 @@ class _Sender:
|
||||||
if self.count >= self.chunk_size:
|
if self.count >= self.chunk_size:
|
||||||
self.count = 0
|
self.count = 0
|
||||||
self.send()
|
self.send()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_instantiated_vocab(vocab: Union[bool, Vocab], pipe_instances: Dict[str, Any]) -> Union[bool, Vocab]:
|
||||||
|
vocab_instances = {}
|
||||||
|
for name, instance in pipe_instances.items():
|
||||||
|
if hasattr(instance, "vocab") and isinstance(instance.vocab, Vocab):
|
||||||
|
vocab_instances[name] = instance.vocab
|
||||||
|
if not vocab_instances:
|
||||||
|
return vocab
|
||||||
|
elif isinstance(vocab, Vocab):
|
||||||
|
for name, inst_voc in vocab_instances.items():
|
||||||
|
if inst_voc is not vocab:
|
||||||
|
warnings.warn(Warnings.W125.format(name=name))
|
||||||
|
return vocab
|
||||||
|
else:
|
||||||
|
resolved_vocab = None
|
||||||
|
for name, inst_voc in vocab_instances.items():
|
||||||
|
if resolved_vocab is None:
|
||||||
|
resolved_vocab = inst_voc
|
||||||
|
elif inst_voc is not resolved_vocab:
|
||||||
|
warnings.warn(Warnings.W125.format(name=name))
|
||||||
|
# This is supposed to only be for the type checker --
|
||||||
|
# it should be unreachable
|
||||||
|
assert resolved_vocab is not None
|
||||||
|
return resolved_vocab
|
||||||
|
|
||||||
|
|
||||||
|
def _find_missing_components(
|
||||||
|
pipeline: List[str],
|
||||||
|
pipe_instances: Dict[str, Any],
|
||||||
|
exclude: List[str]
|
||||||
|
) -> List[str]:
|
||||||
|
return [name for name in pipeline if name not in pipe_instances and name not in exclude]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user