mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Support adding pipeline component by instance
This commit is contained in:
parent
6f821efaf3
commit
aa0d747739
|
@ -1,6 +1,6 @@
|
|||
from typing import Iterator, Optional, Any, Dict, Callable, Iterable
|
||||
from typing import Union, Tuple, List, Set, Pattern, Sequence
|
||||
from typing import NoReturn, TYPE_CHECKING, TypeVar, cast, overload
|
||||
from typing import Union, Tuple, List, Set, Pattern, Sequence, overload
|
||||
from typing import NoReturn, TYPE_CHECKING, TypeVar, cast
|
||||
|
||||
from dataclasses import dataclass
|
||||
import random
|
||||
|
@ -52,6 +52,9 @@ DEFAULT_CONFIG = util.load_config(DEFAULT_CONFIG_PATH)
|
|||
# This is the base config for the [pretraining] block and currently not included
|
||||
# in the main config and only added via the 'init fill-config' command
|
||||
DEFAULT_CONFIG_PRETRAIN_PATH = Path(__file__).parent / "default_config_pretraining.cfg"
|
||||
# Factory name indicating that the component wasn't constructed by a factory,
|
||||
# and was instead passed by instance
|
||||
INSTANCE_FACTORY_NAME = "__added_by_instance__"
|
||||
|
||||
# Type variable for contexts piped with documents
|
||||
_AnyContext = TypeVar("_AnyContext")
|
||||
|
@ -743,6 +746,9 @@ class Language:
|
|||
"""Add a component to the processing pipeline. Valid components are
|
||||
callables that take a `Doc` object, modify it and return it. Only one
|
||||
of before/after/first/last can be set. Default behaviour is "last".
|
||||
Components can be added either by factory name or by instance. If
|
||||
an instance is supplied and you serialize the pipeline, you'll need
|
||||
to also pass an instance into spacy.load() to construct the pipeline.
|
||||
|
||||
factory_name (str): Name of the component factory.
|
||||
name (str): Name of pipeline component. Overwrites existing
|
||||
|
@ -790,11 +796,58 @@ class Language:
|
|||
raw_config=raw_config,
|
||||
validate=validate,
|
||||
)
|
||||
pipe_index = self._get_pipe_index(before, after, first, last)
|
||||
self._pipe_meta[name] = self.get_factory_meta(factory_name)
|
||||
pipe_index = self._get_pipe_index(before, after, first, last)
|
||||
self._components.insert(pipe_index, (name, pipe_component))
|
||||
return pipe_component
|
||||
|
||||
def add_pipe_instance(self, component: PipeCallable,
|
||||
/, name: Optional[str] = None,
|
||||
*,
|
||||
before: Optional[Union[str, int]] = None,
|
||||
after: Optional[Union[str, int]] = None,
|
||||
first: Optional[bool] = None,
|
||||
last: Optional[bool] = None,
|
||||
) -> PipeCallable:
|
||||
"""Add a component instance to the processing pipeline. Valid components
|
||||
are callables that take a `Doc` object, modify it and return it. Only one
|
||||
of before/after/first/last can be set. Default behaviour is "last".
|
||||
|
||||
A limitation of this method is that spaCy will not know how to reconstruct
|
||||
your pipeline after you save it out (unlike the 'Language.add_pipe()' method,
|
||||
where you provide a config and let spaCy construct the instance). See 'spacy.load'
|
||||
for details of how to load back a pipeline with components added by instance.
|
||||
|
||||
pipe_instance (Callable[[Doc], Doc]): The component to add.
|
||||
name (str): Name of pipeline component. Overwrites existing
|
||||
component.name attribute if available. If no name is set and
|
||||
the component exposes no name attribute, component.__name__ is
|
||||
used. An error is raised if a name already exists in the pipeline.
|
||||
before (Union[str, int]): Name or index of the component to insert new
|
||||
component directly before.
|
||||
after (Union[str, int]): Name or index of the component to insert new
|
||||
component directly after.
|
||||
first (bool): If True, insert component first in the pipeline.
|
||||
last (bool): If True, insert component last in the pipeline.
|
||||
RETURNS (Callable[[Doc], Doc]): The pipeline component.
|
||||
|
||||
DOCS: https://spacy.io/api/language#add_pipe_instance
|
||||
"""
|
||||
name = name if name is not None else getattr(component, "name")
|
||||
if name is None:
|
||||
raise ValueError("TODO error")
|
||||
if name in self.component_names:
|
||||
raise ValueError(Errors.E007.format(name=name, opts=self.component_names))
|
||||
|
||||
# It would be possible to take arguments for the FactoryMeta here, but we'll then have
|
||||
# a problem on deserialization: where will the data be coming from?
|
||||
# I think if someone wants that, they should register a component function.
|
||||
self._pipe_meta[name] = FactoryMeta(INSTANCE_FACTORY_NAME)
|
||||
self._pipe_configs[name] = Config()
|
||||
pipe_index = self._get_pipe_index(before, after, first, last)
|
||||
self._components.insert(pipe_index, (name, component))
|
||||
return component
|
||||
|
||||
def _get_pipe_index(
|
||||
self,
|
||||
before: Optional[Union[str, int]] = None,
|
||||
|
@ -1690,6 +1743,7 @@ class Language:
|
|||
meta: Dict[str, Any] = SimpleFrozenDict(),
|
||||
auto_fill: bool = True,
|
||||
validate: bool = True,
|
||||
pipe_instances: Dict[str, Any] = SimpleFrozenDict()
|
||||
) -> "Language":
|
||||
"""Create the nlp object from a loaded config. Will set up the tokenizer
|
||||
and language data, add pipeline components etc. If no config is provided,
|
||||
|
@ -1765,6 +1819,11 @@ class Language:
|
|||
|
||||
# Warn about require_gpu usage in jupyter notebook
|
||||
warn_if_jupyter_cupy()
|
||||
# If we've been passed pipe instances, check whether
|
||||
# they have a Vocab instance, and if they do, use
|
||||
# that one. This also performs some additional checks and
|
||||
# warns if there's a mismatch.
|
||||
vocab = _get_instantiated_vocab(vocab, pipe_instances)
|
||||
|
||||
# Note that we don't load vectors here, instead they get loaded explicitly
|
||||
# inside stuff like the spacy train function. If we loaded them here,
|
||||
|
@ -1781,6 +1840,11 @@ class Language:
|
|||
interpolated = filled.interpolate() if not filled.is_interpolated else filled
|
||||
pipeline = interpolated.get("components", {})
|
||||
sourced = util.get_sourced_components(interpolated)
|
||||
# Check for components that aren't in the pipe_instances dict, aren't disabled,
|
||||
# and aren't built by factory.
|
||||
missing_components = _find_missing_components(pipeline, pipe_instances, exclude)
|
||||
if missing_components:
|
||||
raise ValueError(Errors.E1052.format(", ",join(missing_components)))
|
||||
# If components are loaded from a source (existing models), we cache
|
||||
# them here so they're only loaded once
|
||||
source_nlps = {}
|
||||
|
@ -1790,6 +1854,18 @@ class Language:
|
|||
if pipe_name not in pipeline:
|
||||
opts = ", ".join(pipeline.keys())
|
||||
raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
|
||||
if pipe_name in pipe_instances:
|
||||
if pipe_name in exclude:
|
||||
continue
|
||||
else:
|
||||
nlp.add_pipe_instance(
|
||||
pipe_instances[pipe_name]
|
||||
)
|
||||
# Is it important that we instantiate pipes that
|
||||
# aren't excluded? It seems like we would want
|
||||
# the exclude check above. I've left it how it
|
||||
# is though, in case there's some sort of crazy
|
||||
# load-bearing side-effects someone is relying on?
|
||||
pipe_cfg = util.copy_config(pipeline[pipe_name])
|
||||
raw_config = Config(filled["components"][pipe_name])
|
||||
if pipe_name not in exclude:
|
||||
|
@ -2306,3 +2382,36 @@ class _Sender:
|
|||
if self.count >= self.chunk_size:
|
||||
self.count = 0
|
||||
self.send()
|
||||
|
||||
|
||||
def _get_instantiated_vocab(vocab: Union[bool, Vocab], pipe_instances: Dict[str, Any]) -> Union[bool, Vocab]:
|
||||
vocab_instances = {}
|
||||
for name, instance in pipe_instances.items():
|
||||
if hasattr(instance, "vocab") and isinstance(instance.vocab, Vocab):
|
||||
vocab_instances[name] = instance.vocab
|
||||
if not vocab_instances:
|
||||
return vocab
|
||||
elif isinstance(vocab, Vocab):
|
||||
for name, inst_voc in vocab_instances.items():
|
||||
if inst_voc is not vocab:
|
||||
warnings.warn(Warnings.W125.format(name=name))
|
||||
return vocab
|
||||
else:
|
||||
resolved_vocab = None
|
||||
for name, inst_voc in vocab_instances.items():
|
||||
if resolved_vocab is None:
|
||||
resolved_vocab = inst_voc
|
||||
elif inst_voc is not resolved_vocab:
|
||||
warnings.warn(Warnings.W125.format(name=name))
|
||||
# This is supposed to only be for the type checker --
|
||||
# it should be unreachable
|
||||
assert resolved_vocab is not None
|
||||
return resolved_vocab
|
||||
|
||||
|
||||
def _find_missing_components(
|
||||
pipeline: List[str],
|
||||
pipe_instances: Dict[str, Any],
|
||||
exclude: List[str]
|
||||
) -> List[str]:
|
||||
return [name for name in pipeline if name not in pipe_instances and name not in exclude]
|
||||
|
|
Loading…
Reference in New Issue
Block a user