From 6f9d630f7e9c11d8d5f7ba37e3764ef3630c172d Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Tue, 29 Nov 2022 13:20:08 +0100 Subject: [PATCH] Replace Pipe type with Callable in Language (#11803) * Replace Pipe type with Callable in Language * Use Callable[[Doc], Doc] in the docstrings --- spacy/cli/debug_data.py | 2 + spacy/language.py | 55 ++++++++++++++-------------- spacy/tests/pipeline/test_textcat.py | 2 +- spacy/util.py | 7 ++-- 4 files changed, 33 insertions(+), 33 deletions(-) diff --git a/spacy/cli/debug_data.py b/spacy/cli/debug_data.py index 963d5b926..a85324e87 100644 --- a/spacy/cli/debug_data.py +++ b/spacy/cli/debug_data.py @@ -13,6 +13,7 @@ from ._util import import_code, debug_cli, _format_number from ..training import Example, remove_bilu_prefix from ..training.initialize import get_sourced_components from ..schemas import ConfigSchemaTraining +from ..pipeline import TrainablePipe from ..pipeline._parser_internals import nonproj from ..pipeline._parser_internals.nonproj import DELIMITER from ..pipeline import Morphologizer, SpanCategorizer @@ -934,6 +935,7 @@ def _get_labels_from_model(nlp: Language, factory_name: str) -> Set[str]: labels: Set[str] = set() for pipe_name in pipe_names: pipe = nlp.get_pipe(pipe_name) + assert isinstance(pipe, TrainablePipe) labels.update(pipe.labels) return labels diff --git a/spacy/language.py b/spacy/language.py index 2789b6690..e0abfd5e7 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -43,8 +43,7 @@ from .lookups import load_lookups from .compat import Literal -if TYPE_CHECKING: - from .pipeline import Pipe # noqa: F401 +PipeCallable = Callable[[Doc], Doc] # This is the base config will all settings (training etc.) @@ -181,7 +180,7 @@ class Language: self.vocab: Vocab = vocab if self.lang is None: self.lang = self.vocab.lang - self._components: List[Tuple[str, "Pipe"]] = [] + self._components: List[Tuple[str, PipeCallable]] = [] self._disabled: Set[str] = set() self.max_length = max_length # Create the default tokenizer from the default config @@ -303,7 +302,7 @@ class Language: return SimpleFrozenList(names) @property - def components(self) -> List[Tuple[str, "Pipe"]]: + def components(self) -> List[Tuple[str, PipeCallable]]: """Get all (name, component) tuples in the pipeline, including the currently disabled components. """ @@ -322,12 +321,12 @@ class Language: return SimpleFrozenList(names, error=Errors.E926.format(attr="component_names")) @property - def pipeline(self) -> List[Tuple[str, "Pipe"]]: + def pipeline(self) -> List[Tuple[str, PipeCallable]]: """The processing pipeline consisting of (name, component) tuples. The components are called on the Doc in order as it passes through the pipeline. - RETURNS (List[Tuple[str, Pipe]]): The pipeline. + RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline. """ pipes = [(n, p) for n, p in self._components if n not in self._disabled] return SimpleFrozenList(pipes, error=Errors.E926.format(attr="pipeline")) @@ -527,7 +526,7 @@ class Language: assigns: Iterable[str] = SimpleFrozenList(), requires: Iterable[str] = SimpleFrozenList(), retokenizes: bool = False, - func: Optional["Pipe"] = None, + func: Optional[PipeCallable] = None, ) -> Callable[..., Any]: """Register a new pipeline component. Can be used for stateless function components that don't require a separate factory. Can be used as a @@ -542,7 +541,7 @@ class Language: e.g. "token.ent_id". Used for pipeline analysis. retokenizes (bool): Whether the component changes the tokenization. Used for pipeline analysis. - func (Optional[Callable]): Factory function if not used as a decorator. + func (Optional[Callable[[Doc], Doc]): Factory function if not used as a decorator. DOCS: https://spacy.io/api/language#component """ @@ -553,11 +552,11 @@ class Language: raise ValueError(Errors.E853.format(name=name)) component_name = name if name is not None else util.get_object_name(func) - def add_component(component_func: "Pipe") -> Callable: + def add_component(component_func: PipeCallable) -> Callable: if isinstance(func, type): # function is a class raise ValueError(Errors.E965.format(name=component_name)) - def factory_func(nlp, name: str) -> "Pipe": + def factory_func(nlp, name: str) -> PipeCallable: return component_func internal_name = cls.get_factory_name(name) @@ -607,7 +606,7 @@ class Language: print_pipe_analysis(analysis, keys=keys) return analysis - def get_pipe(self, name: str) -> "Pipe": + def get_pipe(self, name: str) -> PipeCallable: """Get a pipeline component for a given component name. name (str): Name of pipeline component to get. @@ -628,7 +627,7 @@ class Language: config: Dict[str, Any] = SimpleFrozenDict(), raw_config: Optional[Config] = None, validate: bool = True, - ) -> "Pipe": + ) -> PipeCallable: """Create a pipeline component. Mostly used internally. To create and add a component to the pipeline, you can use nlp.add_pipe. @@ -640,7 +639,7 @@ class Language: raw_config (Optional[Config]): Internals: the non-interpolated config. validate (bool): Whether to validate the component config against the arguments and types expected by the factory. - RETURNS (Pipe): The pipeline component. + RETURNS (Callable[[Doc], Doc]): The pipeline component. DOCS: https://spacy.io/api/language#create_pipe """ @@ -695,13 +694,13 @@ class Language: def create_pipe_from_source( self, source_name: str, source: "Language", *, name: str - ) -> Tuple["Pipe", str]: + ) -> Tuple[PipeCallable, str]: """Create a pipeline component by copying it from an existing model. source_name (str): Name of the component in the source pipeline. source (Language): The source nlp object to copy from. name (str): Optional alternative name to use in current pipeline. - RETURNS (Tuple[Callable, str]): The component and its factory name. + RETURNS (Tuple[Callable[[Doc], Doc], str]): The component and its factory name. """ # Check source type if not isinstance(source, Language): @@ -740,7 +739,7 @@ class Language: config: Dict[str, Any] = SimpleFrozenDict(), raw_config: Optional[Config] = None, validate: bool = True, - ) -> "Pipe": + ) -> PipeCallable: """Add a component to the processing pipeline. Valid components are callables that take a `Doc` object, modify it and return it. Only one of before/after/first/last can be set. Default behaviour is "last". @@ -763,7 +762,7 @@ class Language: raw_config (Optional[Config]): Internals: the non-interpolated config. validate (bool): Whether to validate the component config against the arguments and types expected by the factory. - RETURNS (Pipe): The pipeline component. + RETURNS (Callable[[Doc], Doc]): The pipeline component. DOCS: https://spacy.io/api/language#add_pipe """ @@ -869,7 +868,7 @@ class Language: *, config: Dict[str, Any] = SimpleFrozenDict(), validate: bool = True, - ) -> "Pipe": + ) -> PipeCallable: """Replace a component in the pipeline. name (str): Name of the component to replace. @@ -878,7 +877,7 @@ class Language: component. Will be merged with default config, if available. validate (bool): Whether to validate the component config against the arguments and types expected by the factory. - RETURNS (Pipe): The new pipeline component. + RETURNS (Callable[[Doc], Doc]): The new pipeline component. DOCS: https://spacy.io/api/language#replace_pipe """ @@ -930,11 +929,11 @@ class Language: init_cfg = self._config["initialize"]["components"].pop(old_name) self._config["initialize"]["components"][new_name] = init_cfg - def remove_pipe(self, name: str) -> Tuple[str, "Pipe"]: + def remove_pipe(self, name: str) -> Tuple[str, PipeCallable]: """Remove a component from the pipeline. name (str): Name of the component to remove. - RETURNS (tuple): A `(name, component)` tuple of the removed component. + RETURNS (Tuple[str, Callable[[Doc], Doc]]): A `(name, component)` tuple of the removed component. DOCS: https://spacy.io/api/language#remove_pipe """ @@ -1349,15 +1348,15 @@ class Language: def set_error_handler( self, - error_handler: Callable[[str, "Pipe", List[Doc], Exception], NoReturn], + error_handler: Callable[[str, PipeCallable, List[Doc], Exception], NoReturn], ): - """Set an error handler object for all the components in the pipeline that implement - a set_error_handler function. + """Set an error handler object for all the components in the pipeline + that implement a set_error_handler function. - error_handler (Callable[[str, Pipe, List[Doc], Exception], NoReturn]): - Function that deals with a failing batch of documents. This callable function should take in - the component's name, the component itself, the offending batch of documents, and the exception - that was thrown. + error_handler (Callable[[str, Callable[[Doc], Doc], List[Doc], Exception], NoReturn]): + Function that deals with a failing batch of documents. This callable + function should take in the component's name, the component itself, + the offending batch of documents, and the exception that was thrown. DOCS: https://spacy.io/api/language#set_error_handler """ self.default_error_handler = error_handler diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py index 2eda9deaf..155ce99a2 100644 --- a/spacy/tests/pipeline/test_textcat.py +++ b/spacy/tests/pipeline/test_textcat.py @@ -838,8 +838,8 @@ def test_textcat_loss(multi_label: bool, expected_loss: float): textcat = nlp.add_pipe("textcat_multilabel") else: textcat = nlp.add_pipe("textcat") - textcat.initialize(lambda: train_examples) assert isinstance(textcat, TextCategorizer) + textcat.initialize(lambda: train_examples) scores = textcat.model.ops.asarray( [[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]], dtype="f" # type: ignore ) diff --git a/spacy/util.py b/spacy/util.py index 76a1e0bfa..cba403361 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -51,8 +51,7 @@ from . import about if TYPE_CHECKING: # This lets us add type hints for mypy etc. without causing circular imports - from .language import Language # noqa: F401 - from .pipeline import Pipe # noqa: F401 + from .language import Language, PipeCallable # noqa: F401 from .tokens import Doc, Span # noqa: F401 from .vocab import Vocab # noqa: F401 @@ -1642,9 +1641,9 @@ def check_bool_env_var(env_var: str) -> bool: def _pipe( docs: Iterable["Doc"], - proc: "Pipe", + proc: "PipeCallable", name: str, - default_error_handler: Callable[[str, "Pipe", List["Doc"], Exception], NoReturn], + default_error_handler: Callable[[str, "PipeCallable", List["Doc"], Exception], NoReturn], kwargs: Mapping[str, Any], ) -> Iterator["Doc"]: if hasattr(proc, "pipe"):