mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Replace Pipe type with Callable in Language (#11803)
* Replace Pipe type with Callable in Language
* Use Callable[[Doc], Doc] in the docstrings
This commit is contained in:
parent
f1e0243450
commit
6f9d630f7e
|
@ -13,6 +13,7 @@ from ._util import import_code, debug_cli, _format_number
|
|||
from ..training import Example, remove_bilu_prefix
|
||||
from ..training.initialize import get_sourced_components
|
||||
from ..schemas import ConfigSchemaTraining
|
||||
from ..pipeline import TrainablePipe
|
||||
from ..pipeline._parser_internals import nonproj
|
||||
from ..pipeline._parser_internals.nonproj import DELIMITER
|
||||
from ..pipeline import Morphologizer, SpanCategorizer
|
||||
|
@ -934,6 +935,7 @@ def _get_labels_from_model(nlp: Language, factory_name: str) -> Set[str]:
|
|||
labels: Set[str] = set()
|
||||
for pipe_name in pipe_names:
|
||||
pipe = nlp.get_pipe(pipe_name)
|
||||
assert isinstance(pipe, TrainablePipe)
|
||||
labels.update(pipe.labels)
|
||||
return labels
|
||||
|
||||
|
|
|
@ -43,8 +43,7 @@ from .lookups import load_lookups
|
|||
from .compat import Literal
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .pipeline import Pipe # noqa: F401
|
||||
PipeCallable = Callable[[Doc], Doc]
|
||||
|
||||
|
||||
# This is the base config with all settings (training etc.)
|
||||
|
@ -181,7 +180,7 @@ class Language:
|
|||
self.vocab: Vocab = vocab
|
||||
if self.lang is None:
|
||||
self.lang = self.vocab.lang
|
||||
self._components: List[Tuple[str, "Pipe"]] = []
|
||||
self._components: List[Tuple[str, PipeCallable]] = []
|
||||
self._disabled: Set[str] = set()
|
||||
self.max_length = max_length
|
||||
# Create the default tokenizer from the default config
|
||||
|
@ -303,7 +302,7 @@ class Language:
|
|||
return SimpleFrozenList(names)
|
||||
|
||||
@property
|
||||
def components(self) -> List[Tuple[str, "Pipe"]]:
|
||||
def components(self) -> List[Tuple[str, PipeCallable]]:
|
||||
"""Get all (name, component) tuples in the pipeline, including the
|
||||
currently disabled components.
|
||||
"""
|
||||
|
@ -322,12 +321,12 @@ class Language:
|
|||
return SimpleFrozenList(names, error=Errors.E926.format(attr="component_names"))
|
||||
|
||||
@property
|
||||
def pipeline(self) -> List[Tuple[str, "Pipe"]]:
|
||||
def pipeline(self) -> List[Tuple[str, PipeCallable]]:
|
||||
"""The processing pipeline consisting of (name, component) tuples. The
|
||||
components are called on the Doc in order as it passes through the
|
||||
pipeline.
|
||||
|
||||
RETURNS (List[Tuple[str, Pipe]]): The pipeline.
|
||||
RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline.
|
||||
"""
|
||||
pipes = [(n, p) for n, p in self._components if n not in self._disabled]
|
||||
return SimpleFrozenList(pipes, error=Errors.E926.format(attr="pipeline"))
|
||||
|
@ -527,7 +526,7 @@ class Language:
|
|||
assigns: Iterable[str] = SimpleFrozenList(),
|
||||
requires: Iterable[str] = SimpleFrozenList(),
|
||||
retokenizes: bool = False,
|
||||
func: Optional["Pipe"] = None,
|
||||
func: Optional[PipeCallable] = None,
|
||||
) -> Callable[..., Any]:
|
||||
"""Register a new pipeline component. Can be used for stateless function
|
||||
components that don't require a separate factory. Can be used as a
|
||||
|
@ -542,7 +541,7 @@ class Language:
|
|||
e.g. "token.ent_id". Used for pipeline analysis.
|
||||
retokenizes (bool): Whether the component changes the tokenization.
|
||||
Used for pipeline analysis.
|
||||
func (Optional[Callable]): Factory function if not used as a decorator.
|
||||
func (Optional[Callable[[Doc], Doc]]): Factory function if not used as a decorator.
|
||||
|
||||
DOCS: https://spacy.io/api/language#component
|
||||
"""
|
||||
|
@ -553,11 +552,11 @@ class Language:
|
|||
raise ValueError(Errors.E853.format(name=name))
|
||||
component_name = name if name is not None else util.get_object_name(func)
|
||||
|
||||
def add_component(component_func: "Pipe") -> Callable:
|
||||
def add_component(component_func: PipeCallable) -> Callable:
|
||||
if isinstance(func, type): # function is a class
|
||||
raise ValueError(Errors.E965.format(name=component_name))
|
||||
|
||||
def factory_func(nlp, name: str) -> "Pipe":
|
||||
def factory_func(nlp, name: str) -> PipeCallable:
|
||||
return component_func
|
||||
|
||||
internal_name = cls.get_factory_name(name)
|
||||
|
@ -607,7 +606,7 @@ class Language:
|
|||
print_pipe_analysis(analysis, keys=keys)
|
||||
return analysis
|
||||
|
||||
def get_pipe(self, name: str) -> "Pipe":
|
||||
def get_pipe(self, name: str) -> PipeCallable:
|
||||
"""Get a pipeline component for a given component name.
|
||||
|
||||
name (str): Name of pipeline component to get.
|
||||
|
@ -628,7 +627,7 @@ class Language:
|
|||
config: Dict[str, Any] = SimpleFrozenDict(),
|
||||
raw_config: Optional[Config] = None,
|
||||
validate: bool = True,
|
||||
) -> "Pipe":
|
||||
) -> PipeCallable:
|
||||
"""Create a pipeline component. Mostly used internally. To create and
|
||||
add a component to the pipeline, you can use nlp.add_pipe.
|
||||
|
||||
|
@ -640,7 +639,7 @@ class Language:
|
|||
raw_config (Optional[Config]): Internals: the non-interpolated config.
|
||||
validate (bool): Whether to validate the component config against the
|
||||
arguments and types expected by the factory.
|
||||
RETURNS (Pipe): The pipeline component.
|
||||
RETURNS (Callable[[Doc], Doc]): The pipeline component.
|
||||
|
||||
DOCS: https://spacy.io/api/language#create_pipe
|
||||
"""
|
||||
|
@ -695,13 +694,13 @@ class Language:
|
|||
|
||||
def create_pipe_from_source(
|
||||
self, source_name: str, source: "Language", *, name: str
|
||||
) -> Tuple["Pipe", str]:
|
||||
) -> Tuple[PipeCallable, str]:
|
||||
"""Create a pipeline component by copying it from an existing model.
|
||||
|
||||
source_name (str): Name of the component in the source pipeline.
|
||||
source (Language): The source nlp object to copy from.
|
||||
name (str): Optional alternative name to use in current pipeline.
|
||||
RETURNS (Tuple[Callable, str]): The component and its factory name.
|
||||
RETURNS (Tuple[Callable[[Doc], Doc], str]): The component and its factory name.
|
||||
"""
|
||||
# Check source type
|
||||
if not isinstance(source, Language):
|
||||
|
@ -740,7 +739,7 @@ class Language:
|
|||
config: Dict[str, Any] = SimpleFrozenDict(),
|
||||
raw_config: Optional[Config] = None,
|
||||
validate: bool = True,
|
||||
) -> "Pipe":
|
||||
) -> PipeCallable:
|
||||
"""Add a component to the processing pipeline. Valid components are
|
||||
callables that take a `Doc` object, modify it and return it. Only one
|
||||
of before/after/first/last can be set. Default behaviour is "last".
|
||||
|
@ -763,7 +762,7 @@ class Language:
|
|||
raw_config (Optional[Config]): Internals: the non-interpolated config.
|
||||
validate (bool): Whether to validate the component config against the
|
||||
arguments and types expected by the factory.
|
||||
RETURNS (Pipe): The pipeline component.
|
||||
RETURNS (Callable[[Doc], Doc]): The pipeline component.
|
||||
|
||||
DOCS: https://spacy.io/api/language#add_pipe
|
||||
"""
|
||||
|
@ -869,7 +868,7 @@ class Language:
|
|||
*,
|
||||
config: Dict[str, Any] = SimpleFrozenDict(),
|
||||
validate: bool = True,
|
||||
) -> "Pipe":
|
||||
) -> PipeCallable:
|
||||
"""Replace a component in the pipeline.
|
||||
|
||||
name (str): Name of the component to replace.
|
||||
|
@ -878,7 +877,7 @@ class Language:
|
|||
component. Will be merged with default config, if available.
|
||||
validate (bool): Whether to validate the component config against the
|
||||
arguments and types expected by the factory.
|
||||
RETURNS (Pipe): The new pipeline component.
|
||||
RETURNS (Callable[[Doc], Doc]): The new pipeline component.
|
||||
|
||||
DOCS: https://spacy.io/api/language#replace_pipe
|
||||
"""
|
||||
|
@ -930,11 +929,11 @@ class Language:
|
|||
init_cfg = self._config["initialize"]["components"].pop(old_name)
|
||||
self._config["initialize"]["components"][new_name] = init_cfg
|
||||
|
||||
def remove_pipe(self, name: str) -> Tuple[str, "Pipe"]:
|
||||
def remove_pipe(self, name: str) -> Tuple[str, PipeCallable]:
|
||||
"""Remove a component from the pipeline.
|
||||
|
||||
name (str): Name of the component to remove.
|
||||
RETURNS (tuple): A `(name, component)` tuple of the removed component.
|
||||
RETURNS (Tuple[str, Callable[[Doc], Doc]]): A `(name, component)` tuple of the removed component.
|
||||
|
||||
DOCS: https://spacy.io/api/language#remove_pipe
|
||||
"""
|
||||
|
@ -1349,15 +1348,15 @@ class Language:
|
|||
|
||||
def set_error_handler(
|
||||
self,
|
||||
error_handler: Callable[[str, "Pipe", List[Doc], Exception], NoReturn],
|
||||
error_handler: Callable[[str, PipeCallable, List[Doc], Exception], NoReturn],
|
||||
):
|
||||
"""Set an error handler object for all the components in the pipeline that implement
|
||||
a set_error_handler function.
|
||||
"""Set an error handler object for all the components in the pipeline
|
||||
that implement a set_error_handler function.
|
||||
|
||||
error_handler (Callable[[str, Pipe, List[Doc], Exception], NoReturn]):
|
||||
Function that deals with a failing batch of documents. This callable function should take in
|
||||
the component's name, the component itself, the offending batch of documents, and the exception
|
||||
that was thrown.
|
||||
error_handler (Callable[[str, Callable[[Doc], Doc], List[Doc], Exception], NoReturn]):
|
||||
Function that deals with a failing batch of documents. This callable
|
||||
function should take in the component's name, the component itself,
|
||||
the offending batch of documents, and the exception that was thrown.
|
||||
DOCS: https://spacy.io/api/language#set_error_handler
|
||||
"""
|
||||
self.default_error_handler = error_handler
|
||||
|
|
|
@ -838,8 +838,8 @@ def test_textcat_loss(multi_label: bool, expected_loss: float):
|
|||
textcat = nlp.add_pipe("textcat_multilabel")
|
||||
else:
|
||||
textcat = nlp.add_pipe("textcat")
|
||||
textcat.initialize(lambda: train_examples)
|
||||
assert isinstance(textcat, TextCategorizer)
|
||||
textcat.initialize(lambda: train_examples)
|
||||
scores = textcat.model.ops.asarray(
|
||||
[[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]], dtype="f" # type: ignore
|
||||
)
|
||||
|
|
|
@ -51,8 +51,7 @@ from . import about
|
|||
|
||||
if TYPE_CHECKING:
|
||||
# This lets us add type hints for mypy etc. without causing circular imports
|
||||
from .language import Language # noqa: F401
|
||||
from .pipeline import Pipe # noqa: F401
|
||||
from .language import Language, PipeCallable # noqa: F401
|
||||
from .tokens import Doc, Span # noqa: F401
|
||||
from .vocab import Vocab # noqa: F401
|
||||
|
||||
|
@ -1642,9 +1641,9 @@ def check_bool_env_var(env_var: str) -> bool:
|
|||
|
||||
def _pipe(
|
||||
docs: Iterable["Doc"],
|
||||
proc: "Pipe",
|
||||
proc: "PipeCallable",
|
||||
name: str,
|
||||
default_error_handler: Callable[[str, "Pipe", List["Doc"], Exception], NoReturn],
|
||||
default_error_handler: Callable[[str, "PipeCallable", List["Doc"], Exception], NoReturn],
|
||||
kwargs: Mapping[str, Any],
|
||||
) -> Iterator["Doc"]:
|
||||
if hasattr(proc, "pipe"):
|
||||
|
|
Loading…
Reference in New Issue
Block a user