Replace Pipe type with Callable in Language (#11803)

* Replace Pipe type with Callable in Language

* Use Callable[[Doc], Doc] in the docstrings
This commit is contained in:
Adriane Boyd 2022-11-29 13:20:08 +01:00 committed by GitHub
parent f1e0243450
commit 6f9d630f7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 33 additions and 33 deletions

View File

@ -13,6 +13,7 @@ from ._util import import_code, debug_cli, _format_number
from ..training import Example, remove_bilu_prefix from ..training import Example, remove_bilu_prefix
from ..training.initialize import get_sourced_components from ..training.initialize import get_sourced_components
from ..schemas import ConfigSchemaTraining from ..schemas import ConfigSchemaTraining
from ..pipeline import TrainablePipe
from ..pipeline._parser_internals import nonproj from ..pipeline._parser_internals import nonproj
from ..pipeline._parser_internals.nonproj import DELIMITER from ..pipeline._parser_internals.nonproj import DELIMITER
from ..pipeline import Morphologizer, SpanCategorizer from ..pipeline import Morphologizer, SpanCategorizer
@ -934,6 +935,7 @@ def _get_labels_from_model(nlp: Language, factory_name: str) -> Set[str]:
labels: Set[str] = set() labels: Set[str] = set()
for pipe_name in pipe_names: for pipe_name in pipe_names:
pipe = nlp.get_pipe(pipe_name) pipe = nlp.get_pipe(pipe_name)
assert isinstance(pipe, TrainablePipe)
labels.update(pipe.labels) labels.update(pipe.labels)
return labels return labels

View File

@ -43,8 +43,7 @@ from .lookups import load_lookups
from .compat import Literal from .compat import Literal
if TYPE_CHECKING: PipeCallable = Callable[[Doc], Doc]
from .pipeline import Pipe # noqa: F401
# This is the base config with all settings (training etc.) # This is the base config with all settings (training etc.)
@ -181,7 +180,7 @@ class Language:
self.vocab: Vocab = vocab self.vocab: Vocab = vocab
if self.lang is None: if self.lang is None:
self.lang = self.vocab.lang self.lang = self.vocab.lang
self._components: List[Tuple[str, "Pipe"]] = [] self._components: List[Tuple[str, PipeCallable]] = []
self._disabled: Set[str] = set() self._disabled: Set[str] = set()
self.max_length = max_length self.max_length = max_length
# Create the default tokenizer from the default config # Create the default tokenizer from the default config
@ -303,7 +302,7 @@ class Language:
return SimpleFrozenList(names) return SimpleFrozenList(names)
@property @property
def components(self) -> List[Tuple[str, "Pipe"]]: def components(self) -> List[Tuple[str, PipeCallable]]:
"""Get all (name, component) tuples in the pipeline, including the """Get all (name, component) tuples in the pipeline, including the
currently disabled components. currently disabled components.
""" """
@ -322,12 +321,12 @@ class Language:
return SimpleFrozenList(names, error=Errors.E926.format(attr="component_names")) return SimpleFrozenList(names, error=Errors.E926.format(attr="component_names"))
@property @property
def pipeline(self) -> List[Tuple[str, "Pipe"]]: def pipeline(self) -> List[Tuple[str, PipeCallable]]:
"""The processing pipeline consisting of (name, component) tuples. The """The processing pipeline consisting of (name, component) tuples. The
components are called on the Doc in order as it passes through the components are called on the Doc in order as it passes through the
pipeline. pipeline.
RETURNS (List[Tuple[str, Pipe]]): The pipeline. RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline.
""" """
pipes = [(n, p) for n, p in self._components if n not in self._disabled] pipes = [(n, p) for n, p in self._components if n not in self._disabled]
return SimpleFrozenList(pipes, error=Errors.E926.format(attr="pipeline")) return SimpleFrozenList(pipes, error=Errors.E926.format(attr="pipeline"))
@ -527,7 +526,7 @@ class Language:
assigns: Iterable[str] = SimpleFrozenList(), assigns: Iterable[str] = SimpleFrozenList(),
requires: Iterable[str] = SimpleFrozenList(), requires: Iterable[str] = SimpleFrozenList(),
retokenizes: bool = False, retokenizes: bool = False,
func: Optional["Pipe"] = None, func: Optional[PipeCallable] = None,
) -> Callable[..., Any]: ) -> Callable[..., Any]:
"""Register a new pipeline component. Can be used for stateless function """Register a new pipeline component. Can be used for stateless function
components that don't require a separate factory. Can be used as a components that don't require a separate factory. Can be used as a
@ -542,7 +541,7 @@ class Language:
e.g. "token.ent_id". Used for pipeline analysis. e.g. "token.ent_id". Used for pipeline analysis.
retokenizes (bool): Whether the component changes the tokenization. retokenizes (bool): Whether the component changes the tokenization.
Used for pipeline analysis. Used for pipeline analysis.
func (Optional[Callable]): Factory function if not used as a decorator. func (Optional[Callable[[Doc], Doc]]): Factory function if not used as a decorator.
DOCS: https://spacy.io/api/language#component DOCS: https://spacy.io/api/language#component
""" """
@ -553,11 +552,11 @@ class Language:
raise ValueError(Errors.E853.format(name=name)) raise ValueError(Errors.E853.format(name=name))
component_name = name if name is not None else util.get_object_name(func) component_name = name if name is not None else util.get_object_name(func)
def add_component(component_func: "Pipe") -> Callable: def add_component(component_func: PipeCallable) -> Callable:
if isinstance(func, type): # function is a class if isinstance(func, type): # function is a class
raise ValueError(Errors.E965.format(name=component_name)) raise ValueError(Errors.E965.format(name=component_name))
def factory_func(nlp, name: str) -> "Pipe": def factory_func(nlp, name: str) -> PipeCallable:
return component_func return component_func
internal_name = cls.get_factory_name(name) internal_name = cls.get_factory_name(name)
@ -607,7 +606,7 @@ class Language:
print_pipe_analysis(analysis, keys=keys) print_pipe_analysis(analysis, keys=keys)
return analysis return analysis
def get_pipe(self, name: str) -> "Pipe": def get_pipe(self, name: str) -> PipeCallable:
"""Get a pipeline component for a given component name. """Get a pipeline component for a given component name.
name (str): Name of pipeline component to get. name (str): Name of pipeline component to get.
@ -628,7 +627,7 @@ class Language:
config: Dict[str, Any] = SimpleFrozenDict(), config: Dict[str, Any] = SimpleFrozenDict(),
raw_config: Optional[Config] = None, raw_config: Optional[Config] = None,
validate: bool = True, validate: bool = True,
) -> "Pipe": ) -> PipeCallable:
"""Create a pipeline component. Mostly used internally. To create and """Create a pipeline component. Mostly used internally. To create and
add a component to the pipeline, you can use nlp.add_pipe. add a component to the pipeline, you can use nlp.add_pipe.
@ -640,7 +639,7 @@ class Language:
raw_config (Optional[Config]): Internals: the non-interpolated config. raw_config (Optional[Config]): Internals: the non-interpolated config.
validate (bool): Whether to validate the component config against the validate (bool): Whether to validate the component config against the
arguments and types expected by the factory. arguments and types expected by the factory.
RETURNS (Pipe): The pipeline component. RETURNS (Callable[[Doc], Doc]): The pipeline component.
DOCS: https://spacy.io/api/language#create_pipe DOCS: https://spacy.io/api/language#create_pipe
""" """
@ -695,13 +694,13 @@ class Language:
def create_pipe_from_source( def create_pipe_from_source(
self, source_name: str, source: "Language", *, name: str self, source_name: str, source: "Language", *, name: str
) -> Tuple["Pipe", str]: ) -> Tuple[PipeCallable, str]:
"""Create a pipeline component by copying it from an existing model. """Create a pipeline component by copying it from an existing model.
source_name (str): Name of the component in the source pipeline. source_name (str): Name of the component in the source pipeline.
source (Language): The source nlp object to copy from. source (Language): The source nlp object to copy from.
name (str): Optional alternative name to use in current pipeline. name (str): Optional alternative name to use in current pipeline.
RETURNS (Tuple[Callable, str]): The component and its factory name. RETURNS (Tuple[Callable[[Doc], Doc], str]): The component and its factory name.
""" """
# Check source type # Check source type
if not isinstance(source, Language): if not isinstance(source, Language):
@ -740,7 +739,7 @@ class Language:
config: Dict[str, Any] = SimpleFrozenDict(), config: Dict[str, Any] = SimpleFrozenDict(),
raw_config: Optional[Config] = None, raw_config: Optional[Config] = None,
validate: bool = True, validate: bool = True,
) -> "Pipe": ) -> PipeCallable:
"""Add a component to the processing pipeline. Valid components are """Add a component to the processing pipeline. Valid components are
callables that take a `Doc` object, modify it and return it. Only one callables that take a `Doc` object, modify it and return it. Only one
of before/after/first/last can be set. Default behaviour is "last". of before/after/first/last can be set. Default behaviour is "last".
@ -763,7 +762,7 @@ class Language:
raw_config (Optional[Config]): Internals: the non-interpolated config. raw_config (Optional[Config]): Internals: the non-interpolated config.
validate (bool): Whether to validate the component config against the validate (bool): Whether to validate the component config against the
arguments and types expected by the factory. arguments and types expected by the factory.
RETURNS (Pipe): The pipeline component. RETURNS (Callable[[Doc], Doc]): The pipeline component.
DOCS: https://spacy.io/api/language#add_pipe DOCS: https://spacy.io/api/language#add_pipe
""" """
@ -869,7 +868,7 @@ class Language:
*, *,
config: Dict[str, Any] = SimpleFrozenDict(), config: Dict[str, Any] = SimpleFrozenDict(),
validate: bool = True, validate: bool = True,
) -> "Pipe": ) -> PipeCallable:
"""Replace a component in the pipeline. """Replace a component in the pipeline.
name (str): Name of the component to replace. name (str): Name of the component to replace.
@ -878,7 +877,7 @@ class Language:
component. Will be merged with default config, if available. component. Will be merged with default config, if available.
validate (bool): Whether to validate the component config against the validate (bool): Whether to validate the component config against the
arguments and types expected by the factory. arguments and types expected by the factory.
RETURNS (Pipe): The new pipeline component. RETURNS (Callable[[Doc], Doc]): The new pipeline component.
DOCS: https://spacy.io/api/language#replace_pipe DOCS: https://spacy.io/api/language#replace_pipe
""" """
@ -930,11 +929,11 @@ class Language:
init_cfg = self._config["initialize"]["components"].pop(old_name) init_cfg = self._config["initialize"]["components"].pop(old_name)
self._config["initialize"]["components"][new_name] = init_cfg self._config["initialize"]["components"][new_name] = init_cfg
def remove_pipe(self, name: str) -> Tuple[str, "Pipe"]: def remove_pipe(self, name: str) -> Tuple[str, PipeCallable]:
"""Remove a component from the pipeline. """Remove a component from the pipeline.
name (str): Name of the component to remove. name (str): Name of the component to remove.
RETURNS (tuple): A `(name, component)` tuple of the removed component. RETURNS (Tuple[str, Callable[[Doc], Doc]]): A `(name, component)` tuple of the removed component.
DOCS: https://spacy.io/api/language#remove_pipe DOCS: https://spacy.io/api/language#remove_pipe
""" """
@ -1349,15 +1348,15 @@ class Language:
def set_error_handler( def set_error_handler(
self, self,
error_handler: Callable[[str, "Pipe", List[Doc], Exception], NoReturn], error_handler: Callable[[str, PipeCallable, List[Doc], Exception], NoReturn],
): ):
"""Set an error handler object for all the components in the pipeline that implement """Set an error handler object for all the components in the pipeline
a set_error_handler function. that implement a set_error_handler function.
error_handler (Callable[[str, Pipe, List[Doc], Exception], NoReturn]): error_handler (Callable[[str, Callable[[Doc], Doc], List[Doc], Exception], NoReturn]):
Function that deals with a failing batch of documents. This callable function should take in Function that deals with a failing batch of documents. This callable
the component's name, the component itself, the offending batch of documents, and the exception function should take in the component's name, the component itself,
that was thrown. the offending batch of documents, and the exception that was thrown.
DOCS: https://spacy.io/api/language#set_error_handler DOCS: https://spacy.io/api/language#set_error_handler
""" """
self.default_error_handler = error_handler self.default_error_handler = error_handler

View File

@ -838,8 +838,8 @@ def test_textcat_loss(multi_label: bool, expected_loss: float):
textcat = nlp.add_pipe("textcat_multilabel") textcat = nlp.add_pipe("textcat_multilabel")
else: else:
textcat = nlp.add_pipe("textcat") textcat = nlp.add_pipe("textcat")
textcat.initialize(lambda: train_examples)
assert isinstance(textcat, TextCategorizer) assert isinstance(textcat, TextCategorizer)
textcat.initialize(lambda: train_examples)
scores = textcat.model.ops.asarray( scores = textcat.model.ops.asarray(
[[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]], dtype="f" # type: ignore [[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]], dtype="f" # type: ignore
) )

View File

@ -51,8 +51,7 @@ from . import about
if TYPE_CHECKING: if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports # This lets us add type hints for mypy etc. without causing circular imports
from .language import Language # noqa: F401 from .language import Language, PipeCallable # noqa: F401
from .pipeline import Pipe # noqa: F401
from .tokens import Doc, Span # noqa: F401 from .tokens import Doc, Span # noqa: F401
from .vocab import Vocab # noqa: F401 from .vocab import Vocab # noqa: F401
@ -1642,9 +1641,9 @@ def check_bool_env_var(env_var: str) -> bool:
def _pipe( def _pipe(
docs: Iterable["Doc"], docs: Iterable["Doc"],
proc: "Pipe", proc: "PipeCallable",
name: str, name: str,
default_error_handler: Callable[[str, "Pipe", List["Doc"], Exception], NoReturn], default_error_handler: Callable[[str, "PipeCallable", List["Doc"], Exception], NoReturn],
kwargs: Mapping[str, Any], kwargs: Mapping[str, Any],
) -> Iterator["Doc"]: ) -> Iterator["Doc"]:
if hasattr(proc, "pipe"): if hasattr(proc, "pipe"):