mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Support adding pipeline component by instance
This commit is contained in:
		
							parent
							
								
									6f821efaf3
								
							
						
					
					
						commit
						aa0d747739
					
				| 
						 | 
					@ -1,6 +1,6 @@
 | 
				
			||||||
from typing import Iterator, Optional, Any, Dict, Callable, Iterable
 | 
					from typing import Iterator, Optional, Any, Dict, Callable, Iterable
 | 
				
			||||||
from typing import Union, Tuple, List, Set, Pattern, Sequence
 | 
					from typing import Union, Tuple, List, Set, Pattern, Sequence, overload
 | 
				
			||||||
from typing import NoReturn, TYPE_CHECKING, TypeVar, cast, overload
 | 
					from typing import NoReturn, TYPE_CHECKING, TypeVar, cast
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from dataclasses import dataclass
 | 
					from dataclasses import dataclass
 | 
				
			||||||
import random
 | 
					import random
 | 
				
			||||||
| 
						 | 
					@ -52,6 +52,9 @@ DEFAULT_CONFIG = util.load_config(DEFAULT_CONFIG_PATH)
 | 
				
			||||||
# This is the base config for the [pretraining] block and currently not included
 | 
					# This is the base config for the [pretraining] block and currently not included
 | 
				
			||||||
# in the main config and only added via the 'init fill-config' command
 | 
					# in the main config and only added via the 'init fill-config' command
 | 
				
			||||||
DEFAULT_CONFIG_PRETRAIN_PATH = Path(__file__).parent / "default_config_pretraining.cfg"
 | 
					DEFAULT_CONFIG_PRETRAIN_PATH = Path(__file__).parent / "default_config_pretraining.cfg"
 | 
				
			||||||
 | 
					# Factory name indicating that the component wasn't constructed by a factory,
 | 
				
			||||||
 | 
					# and was instead passed by instance
 | 
				
			||||||
 | 
					INSTANCE_FACTORY_NAME = "__added_by_instance__"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Type variable for contexts piped with documents
 | 
					# Type variable for contexts piped with documents
 | 
				
			||||||
_AnyContext = TypeVar("_AnyContext")
 | 
					_AnyContext = TypeVar("_AnyContext")
 | 
				
			||||||
| 
						 | 
					@ -743,6 +746,9 @@ class Language:
 | 
				
			||||||
        """Add a component to the processing pipeline. Valid components are
 | 
					        """Add a component to the processing pipeline. Valid components are
 | 
				
			||||||
        callables that take a `Doc` object, modify it and return it. Only one
 | 
					        callables that take a `Doc` object, modify it and return it. Only one
 | 
				
			||||||
        of before/after/first/last can be set. Default behaviour is "last".
 | 
					        of before/after/first/last can be set. Default behaviour is "last".
 | 
				
			||||||
 | 
					        Components can be added either by factory name or by instance. If
 | 
				
			||||||
 | 
					        an instance is supplied and you serialize the pipeline, you'll need
 | 
				
			||||||
 | 
					        to also pass an instance into spacy.load() to construct the pipeline.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        factory_name (str): Name of the component factory.
 | 
					        factory_name (str): Name of the component factory.
 | 
				
			||||||
        name (str): Name of pipeline component. Overwrites existing
 | 
					        name (str): Name of pipeline component. Overwrites existing
 | 
				
			||||||
| 
						 | 
					@ -790,11 +796,58 @@ class Language:
 | 
				
			||||||
                raw_config=raw_config,
 | 
					                raw_config=raw_config,
 | 
				
			||||||
                validate=validate,
 | 
					                validate=validate,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        pipe_index = self._get_pipe_index(before, after, first, last)
 | 
					 | 
				
			||||||
        self._pipe_meta[name] = self.get_factory_meta(factory_name)
 | 
					        self._pipe_meta[name] = self.get_factory_meta(factory_name)
 | 
				
			||||||
 | 
					        pipe_index = self._get_pipe_index(before, after, first, last)
 | 
				
			||||||
        self._components.insert(pipe_index, (name, pipe_component))
 | 
					        self._components.insert(pipe_index, (name, pipe_component))
 | 
				
			||||||
        return pipe_component
 | 
					        return pipe_component
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_pipe_instance(self, component: PipeCallable,
 | 
				
			||||||
 | 
					        /, name: Optional[str] = None,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        before: Optional[Union[str, int]] = None,
 | 
				
			||||||
 | 
					        after: Optional[Union[str, int]] = None,
 | 
				
			||||||
 | 
					        first: Optional[bool] = None,
 | 
				
			||||||
 | 
					        last: Optional[bool] = None,
 | 
				
			||||||
 | 
					    ) -> PipeCallable:
 | 
				
			||||||
 | 
					        """Add a component instance to the processing pipeline. Valid components
 | 
				
			||||||
 | 
					        are callables that take a `Doc` object, modify it and return it. Only one
 | 
				
			||||||
 | 
					        of before/after/first/last can be set. Default behaviour is "last".
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        A limitation of this method is that spaCy will not know how to reconstruct
 | 
				
			||||||
 | 
					        your pipeline after you save it out (unlike the 'Language.add_pipe()' method,
 | 
				
			||||||
 | 
					        where you provide a config and let spaCy construct the instance). See 'spacy.load'
 | 
				
			||||||
 | 
					        for details of how to load back a pipeline with components added by instance.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pipe_instance (Callable[[Doc], Doc]): The component to add.
 | 
				
			||||||
 | 
					        name (str): Name of pipeline component. Overwrites existing
 | 
				
			||||||
 | 
					            component.name attribute if available. If no name is set and
 | 
				
			||||||
 | 
					            the component exposes no name attribute, component.__name__ is
 | 
				
			||||||
 | 
					            used. An error is raised if a name already exists in the pipeline.
 | 
				
			||||||
 | 
					        before (Union[str, int]): Name or index of the component to insert new
 | 
				
			||||||
 | 
					            component directly before.
 | 
				
			||||||
 | 
					        after (Union[str, int]): Name or index of the component to insert new
 | 
				
			||||||
 | 
					            component directly after.
 | 
				
			||||||
 | 
					        first (bool): If True, insert component first in the pipeline.
 | 
				
			||||||
 | 
					        last (bool): If True, insert component last in the pipeline.
 | 
				
			||||||
 | 
					        RETURNS (Callable[[Doc], Doc]): The pipeline component.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        DOCS: https://spacy.io/api/language#add_pipe_instance
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        name = name if name is not None else getattr(component, "name")
 | 
				
			||||||
 | 
					        if name is None:
 | 
				
			||||||
 | 
					            raise ValueError("TODO error")
 | 
				
			||||||
 | 
					        if name in self.component_names:
 | 
				
			||||||
 | 
					            raise ValueError(Errors.E007.format(name=name, opts=self.component_names))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # It would be possible to take arguments for the FactoryMeta here, but we'll then have
 | 
				
			||||||
 | 
					        # a problem on deserialization: where will the data be coming from?
 | 
				
			||||||
 | 
					        # I think if someone wants that, they should register a component function.
 | 
				
			||||||
 | 
					        self._pipe_meta[name] = FactoryMeta(INSTANCE_FACTORY_NAME)
 | 
				
			||||||
 | 
					        self._pipe_configs[name] = Config()
 | 
				
			||||||
 | 
					        pipe_index = self._get_pipe_index(before, after, first, last)
 | 
				
			||||||
 | 
					        self._components.insert(pipe_index, (name, component))
 | 
				
			||||||
 | 
					        return component
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _get_pipe_index(
 | 
					    def _get_pipe_index(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        before: Optional[Union[str, int]] = None,
 | 
					        before: Optional[Union[str, int]] = None,
 | 
				
			||||||
| 
						 | 
					@ -1690,6 +1743,7 @@ class Language:
 | 
				
			||||||
        meta: Dict[str, Any] = SimpleFrozenDict(),
 | 
					        meta: Dict[str, Any] = SimpleFrozenDict(),
 | 
				
			||||||
        auto_fill: bool = True,
 | 
					        auto_fill: bool = True,
 | 
				
			||||||
        validate: bool = True,
 | 
					        validate: bool = True,
 | 
				
			||||||
 | 
					        pipe_instances: Dict[str, Any] = SimpleFrozenDict()
 | 
				
			||||||
    ) -> "Language":
 | 
					    ) -> "Language":
 | 
				
			||||||
        """Create the nlp object from a loaded config. Will set up the tokenizer
 | 
					        """Create the nlp object from a loaded config. Will set up the tokenizer
 | 
				
			||||||
        and language data, add pipeline components etc. If no config is provided,
 | 
					        and language data, add pipeline components etc. If no config is provided,
 | 
				
			||||||
| 
						 | 
					@ -1765,6 +1819,11 @@ class Language:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Warn about require_gpu usage in jupyter notebook
 | 
					        # Warn about require_gpu usage in jupyter notebook
 | 
				
			||||||
        warn_if_jupyter_cupy()
 | 
					        warn_if_jupyter_cupy()
 | 
				
			||||||
 | 
					        # If we've been passed pipe instances, check whether
 | 
				
			||||||
 | 
					        # they have a Vocab instance, and if they do, use
 | 
				
			||||||
 | 
					        # that one. This also performs some additional checks and
 | 
				
			||||||
 | 
					        # warns if there's a mismatch.
 | 
				
			||||||
 | 
					        vocab = _get_instantiated_vocab(vocab, pipe_instances)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Note that we don't load vectors here, instead they get loaded explicitly
 | 
					        # Note that we don't load vectors here, instead they get loaded explicitly
 | 
				
			||||||
        # inside stuff like the spacy train function. If we loaded them here,
 | 
					        # inside stuff like the spacy train function. If we loaded them here,
 | 
				
			||||||
| 
						 | 
					@ -1781,6 +1840,11 @@ class Language:
 | 
				
			||||||
        interpolated = filled.interpolate() if not filled.is_interpolated else filled
 | 
					        interpolated = filled.interpolate() if not filled.is_interpolated else filled
 | 
				
			||||||
        pipeline = interpolated.get("components", {})
 | 
					        pipeline = interpolated.get("components", {})
 | 
				
			||||||
        sourced = util.get_sourced_components(interpolated)
 | 
					        sourced = util.get_sourced_components(interpolated)
 | 
				
			||||||
 | 
					        # Check for components that aren't in the pipe_instances dict, aren't disabled,
 | 
				
			||||||
 | 
					        # and aren't built by factory.
 | 
				
			||||||
 | 
					        missing_components = _find_missing_components(pipeline, pipe_instances, exclude)
 | 
				
			||||||
 | 
					        if missing_components:
 | 
				
			||||||
 | 
					            raise ValueError(Errors.E1052.format(", ",join(missing_components)))
 | 
				
			||||||
        # If components are loaded from a source (existing models), we cache
 | 
					        # If components are loaded from a source (existing models), we cache
 | 
				
			||||||
        # them here so they're only loaded once
 | 
					        # them here so they're only loaded once
 | 
				
			||||||
        source_nlps = {}
 | 
					        source_nlps = {}
 | 
				
			||||||
| 
						 | 
					@ -1790,6 +1854,18 @@ class Language:
 | 
				
			||||||
            if pipe_name not in pipeline:
 | 
					            if pipe_name not in pipeline:
 | 
				
			||||||
                opts = ", ".join(pipeline.keys())
 | 
					                opts = ", ".join(pipeline.keys())
 | 
				
			||||||
                raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
 | 
					                raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
 | 
				
			||||||
 | 
					            if pipe_name in pipe_instances:
 | 
				
			||||||
 | 
					                if pipe_name in exclude:
 | 
				
			||||||
 | 
					                    continue
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    nlp.add_pipe_instance(
 | 
				
			||||||
 | 
					                        pipe_instances[pipe_name]
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					            # Is it important that we instantiate pipes that
 | 
				
			||||||
 | 
					            # aren't excluded? It seems like we would want
 | 
				
			||||||
 | 
					            # the exclude check above. I've left it how it
 | 
				
			||||||
 | 
					            # is though, in case there's some sort of crazy
 | 
				
			||||||
 | 
					            # load-bearing side-effects someone is relying on?
 | 
				
			||||||
            pipe_cfg = util.copy_config(pipeline[pipe_name])
 | 
					            pipe_cfg = util.copy_config(pipeline[pipe_name])
 | 
				
			||||||
            raw_config = Config(filled["components"][pipe_name])
 | 
					            raw_config = Config(filled["components"][pipe_name])
 | 
				
			||||||
            if pipe_name not in exclude:
 | 
					            if pipe_name not in exclude:
 | 
				
			||||||
| 
						 | 
					@ -2306,3 +2382,36 @@ class _Sender:
 | 
				
			||||||
        if self.count >= self.chunk_size:
 | 
					        if self.count >= self.chunk_size:
 | 
				
			||||||
            self.count = 0
 | 
					            self.count = 0
 | 
				
			||||||
            self.send()
 | 
					            self.send()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _get_instantiated_vocab(vocab: Union[bool, Vocab], pipe_instances: Dict[str, Any]) -> Union[bool, Vocab]:
 | 
				
			||||||
 | 
					    vocab_instances = {}
 | 
				
			||||||
 | 
					    for name, instance in pipe_instances.items():
 | 
				
			||||||
 | 
					        if hasattr(instance, "vocab") and isinstance(instance.vocab, Vocab):
 | 
				
			||||||
 | 
					            vocab_instances[name] = instance.vocab
 | 
				
			||||||
 | 
					    if not vocab_instances:
 | 
				
			||||||
 | 
					        return vocab
 | 
				
			||||||
 | 
					    elif isinstance(vocab, Vocab):
 | 
				
			||||||
 | 
					        for name, inst_voc in vocab_instances.items():
 | 
				
			||||||
 | 
					            if inst_voc is not vocab:
 | 
				
			||||||
 | 
					                warnings.warn(Warnings.W125.format(name=name))
 | 
				
			||||||
 | 
					        return vocab
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        resolved_vocab = None
 | 
				
			||||||
 | 
					        for name, inst_voc in vocab_instances.items():
 | 
				
			||||||
 | 
					            if resolved_vocab is None:
 | 
				
			||||||
 | 
					                resolved_vocab = inst_voc
 | 
				
			||||||
 | 
					            elif inst_voc is not resolved_vocab:
 | 
				
			||||||
 | 
					                warnings.warn(Warnings.W125.format(name=name))
 | 
				
			||||||
 | 
					        # This is supposed to only be for the type checker --
 | 
				
			||||||
 | 
					        # it should be unreachable
 | 
				
			||||||
 | 
					        assert resolved_vocab is not None
 | 
				
			||||||
 | 
					        return resolved_vocab
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _find_missing_components(
 | 
				
			||||||
 | 
					    pipeline: List[str],
 | 
				
			||||||
 | 
					    pipe_instances: Dict[str, Any],
 | 
				
			||||||
 | 
					    exclude: List[str]
 | 
				
			||||||
 | 
					) -> List[str]:
 | 
				
			||||||
 | 
					    return [name for name in pipeline if name not in pipe_instances and name not in exclude]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user