mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 09:56:28 +03:00
Replace Pipe type with Callable in Language (#11803)
* Replace Pipe type with Callable in Language * Use Callable[[Doc], Doc] in the docstrings
This commit is contained in:
parent
f1e0243450
commit
6f9d630f7e
|
@ -13,6 +13,7 @@ from ._util import import_code, debug_cli, _format_number
|
||||||
from ..training import Example, remove_bilu_prefix
|
from ..training import Example, remove_bilu_prefix
|
||||||
from ..training.initialize import get_sourced_components
|
from ..training.initialize import get_sourced_components
|
||||||
from ..schemas import ConfigSchemaTraining
|
from ..schemas import ConfigSchemaTraining
|
||||||
|
from ..pipeline import TrainablePipe
|
||||||
from ..pipeline._parser_internals import nonproj
|
from ..pipeline._parser_internals import nonproj
|
||||||
from ..pipeline._parser_internals.nonproj import DELIMITER
|
from ..pipeline._parser_internals.nonproj import DELIMITER
|
||||||
from ..pipeline import Morphologizer, SpanCategorizer
|
from ..pipeline import Morphologizer, SpanCategorizer
|
||||||
|
@ -934,6 +935,7 @@ def _get_labels_from_model(nlp: Language, factory_name: str) -> Set[str]:
|
||||||
labels: Set[str] = set()
|
labels: Set[str] = set()
|
||||||
for pipe_name in pipe_names:
|
for pipe_name in pipe_names:
|
||||||
pipe = nlp.get_pipe(pipe_name)
|
pipe = nlp.get_pipe(pipe_name)
|
||||||
|
assert isinstance(pipe, TrainablePipe)
|
||||||
labels.update(pipe.labels)
|
labels.update(pipe.labels)
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
|
|
@ -43,8 +43,7 @@ from .lookups import load_lookups
|
||||||
from .compat import Literal
|
from .compat import Literal
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
PipeCallable = Callable[[Doc], Doc]
|
||||||
from .pipeline import Pipe # noqa: F401
|
|
||||||
|
|
||||||
|
|
||||||
# This is the base config with all settings (training etc.)
|
# This is the base config with all settings (training etc.)
|
||||||
|
@ -181,7 +180,7 @@ class Language:
|
||||||
self.vocab: Vocab = vocab
|
self.vocab: Vocab = vocab
|
||||||
if self.lang is None:
|
if self.lang is None:
|
||||||
self.lang = self.vocab.lang
|
self.lang = self.vocab.lang
|
||||||
self._components: List[Tuple[str, "Pipe"]] = []
|
self._components: List[Tuple[str, PipeCallable]] = []
|
||||||
self._disabled: Set[str] = set()
|
self._disabled: Set[str] = set()
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
# Create the default tokenizer from the default config
|
# Create the default tokenizer from the default config
|
||||||
|
@ -303,7 +302,7 @@ class Language:
|
||||||
return SimpleFrozenList(names)
|
return SimpleFrozenList(names)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def components(self) -> List[Tuple[str, "Pipe"]]:
|
def components(self) -> List[Tuple[str, PipeCallable]]:
|
||||||
"""Get all (name, component) tuples in the pipeline, including the
|
"""Get all (name, component) tuples in the pipeline, including the
|
||||||
currently disabled components.
|
currently disabled components.
|
||||||
"""
|
"""
|
||||||
|
@ -322,12 +321,12 @@ class Language:
|
||||||
return SimpleFrozenList(names, error=Errors.E926.format(attr="component_names"))
|
return SimpleFrozenList(names, error=Errors.E926.format(attr="component_names"))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pipeline(self) -> List[Tuple[str, "Pipe"]]:
|
def pipeline(self) -> List[Tuple[str, PipeCallable]]:
|
||||||
"""The processing pipeline consisting of (name, component) tuples. The
|
"""The processing pipeline consisting of (name, component) tuples. The
|
||||||
components are called on the Doc in order as it passes through the
|
components are called on the Doc in order as it passes through the
|
||||||
pipeline.
|
pipeline.
|
||||||
|
|
||||||
RETURNS (List[Tuple[str, Pipe]]): The pipeline.
|
RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline.
|
||||||
"""
|
"""
|
||||||
pipes = [(n, p) for n, p in self._components if n not in self._disabled]
|
pipes = [(n, p) for n, p in self._components if n not in self._disabled]
|
||||||
return SimpleFrozenList(pipes, error=Errors.E926.format(attr="pipeline"))
|
return SimpleFrozenList(pipes, error=Errors.E926.format(attr="pipeline"))
|
||||||
|
@ -527,7 +526,7 @@ class Language:
|
||||||
assigns: Iterable[str] = SimpleFrozenList(),
|
assigns: Iterable[str] = SimpleFrozenList(),
|
||||||
requires: Iterable[str] = SimpleFrozenList(),
|
requires: Iterable[str] = SimpleFrozenList(),
|
||||||
retokenizes: bool = False,
|
retokenizes: bool = False,
|
||||||
func: Optional["Pipe"] = None,
|
func: Optional[PipeCallable] = None,
|
||||||
) -> Callable[..., Any]:
|
) -> Callable[..., Any]:
|
||||||
"""Register a new pipeline component. Can be used for stateless function
|
"""Register a new pipeline component. Can be used for stateless function
|
||||||
components that don't require a separate factory. Can be used as a
|
components that don't require a separate factory. Can be used as a
|
||||||
|
@ -542,7 +541,7 @@ class Language:
|
||||||
e.g. "token.ent_id". Used for pipeline analysis.
|
e.g. "token.ent_id". Used for pipeline analysis.
|
||||||
retokenizes (bool): Whether the component changes the tokenization.
|
retokenizes (bool): Whether the component changes the tokenization.
|
||||||
Used for pipeline analysis.
|
Used for pipeline analysis.
|
||||||
func (Optional[Callable]): Factory function if not used as a decorator.
|
func (Optional[Callable[[Doc], Doc]]): Factory function if not used as a decorator.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/language#component
|
DOCS: https://spacy.io/api/language#component
|
||||||
"""
|
"""
|
||||||
|
@ -553,11 +552,11 @@ class Language:
|
||||||
raise ValueError(Errors.E853.format(name=name))
|
raise ValueError(Errors.E853.format(name=name))
|
||||||
component_name = name if name is not None else util.get_object_name(func)
|
component_name = name if name is not None else util.get_object_name(func)
|
||||||
|
|
||||||
def add_component(component_func: "Pipe") -> Callable:
|
def add_component(component_func: PipeCallable) -> Callable:
|
||||||
if isinstance(func, type): # function is a class
|
if isinstance(func, type): # function is a class
|
||||||
raise ValueError(Errors.E965.format(name=component_name))
|
raise ValueError(Errors.E965.format(name=component_name))
|
||||||
|
|
||||||
def factory_func(nlp, name: str) -> "Pipe":
|
def factory_func(nlp, name: str) -> PipeCallable:
|
||||||
return component_func
|
return component_func
|
||||||
|
|
||||||
internal_name = cls.get_factory_name(name)
|
internal_name = cls.get_factory_name(name)
|
||||||
|
@ -607,7 +606,7 @@ class Language:
|
||||||
print_pipe_analysis(analysis, keys=keys)
|
print_pipe_analysis(analysis, keys=keys)
|
||||||
return analysis
|
return analysis
|
||||||
|
|
||||||
def get_pipe(self, name: str) -> "Pipe":
|
def get_pipe(self, name: str) -> PipeCallable:
|
||||||
"""Get a pipeline component for a given component name.
|
"""Get a pipeline component for a given component name.
|
||||||
|
|
||||||
name (str): Name of pipeline component to get.
|
name (str): Name of pipeline component to get.
|
||||||
|
@ -628,7 +627,7 @@ class Language:
|
||||||
config: Dict[str, Any] = SimpleFrozenDict(),
|
config: Dict[str, Any] = SimpleFrozenDict(),
|
||||||
raw_config: Optional[Config] = None,
|
raw_config: Optional[Config] = None,
|
||||||
validate: bool = True,
|
validate: bool = True,
|
||||||
) -> "Pipe":
|
) -> PipeCallable:
|
||||||
"""Create a pipeline component. Mostly used internally. To create and
|
"""Create a pipeline component. Mostly used internally. To create and
|
||||||
add a component to the pipeline, you can use nlp.add_pipe.
|
add a component to the pipeline, you can use nlp.add_pipe.
|
||||||
|
|
||||||
|
@ -640,7 +639,7 @@ class Language:
|
||||||
raw_config (Optional[Config]): Internals: the non-interpolated config.
|
raw_config (Optional[Config]): Internals: the non-interpolated config.
|
||||||
validate (bool): Whether to validate the component config against the
|
validate (bool): Whether to validate the component config against the
|
||||||
arguments and types expected by the factory.
|
arguments and types expected by the factory.
|
||||||
RETURNS (Pipe): The pipeline component.
|
RETURNS (Callable[[Doc], Doc]): The pipeline component.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/language#create_pipe
|
DOCS: https://spacy.io/api/language#create_pipe
|
||||||
"""
|
"""
|
||||||
|
@ -695,13 +694,13 @@ class Language:
|
||||||
|
|
||||||
def create_pipe_from_source(
|
def create_pipe_from_source(
|
||||||
self, source_name: str, source: "Language", *, name: str
|
self, source_name: str, source: "Language", *, name: str
|
||||||
) -> Tuple["Pipe", str]:
|
) -> Tuple[PipeCallable, str]:
|
||||||
"""Create a pipeline component by copying it from an existing model.
|
"""Create a pipeline component by copying it from an existing model.
|
||||||
|
|
||||||
source_name (str): Name of the component in the source pipeline.
|
source_name (str): Name of the component in the source pipeline.
|
||||||
source (Language): The source nlp object to copy from.
|
source (Language): The source nlp object to copy from.
|
||||||
name (str): Optional alternative name to use in current pipeline.
|
name (str): Optional alternative name to use in current pipeline.
|
||||||
RETURNS (Tuple[Callable, str]): The component and its factory name.
|
RETURNS (Tuple[Callable[[Doc], Doc], str]): The component and its factory name.
|
||||||
"""
|
"""
|
||||||
# Check source type
|
# Check source type
|
||||||
if not isinstance(source, Language):
|
if not isinstance(source, Language):
|
||||||
|
@ -740,7 +739,7 @@ class Language:
|
||||||
config: Dict[str, Any] = SimpleFrozenDict(),
|
config: Dict[str, Any] = SimpleFrozenDict(),
|
||||||
raw_config: Optional[Config] = None,
|
raw_config: Optional[Config] = None,
|
||||||
validate: bool = True,
|
validate: bool = True,
|
||||||
) -> "Pipe":
|
) -> PipeCallable:
|
||||||
"""Add a component to the processing pipeline. Valid components are
|
"""Add a component to the processing pipeline. Valid components are
|
||||||
callables that take a `Doc` object, modify it and return it. Only one
|
callables that take a `Doc` object, modify it and return it. Only one
|
||||||
of before/after/first/last can be set. Default behaviour is "last".
|
of before/after/first/last can be set. Default behaviour is "last".
|
||||||
|
@ -763,7 +762,7 @@ class Language:
|
||||||
raw_config (Optional[Config]): Internals: the non-interpolated config.
|
raw_config (Optional[Config]): Internals: the non-interpolated config.
|
||||||
validate (bool): Whether to validate the component config against the
|
validate (bool): Whether to validate the component config against the
|
||||||
arguments and types expected by the factory.
|
arguments and types expected by the factory.
|
||||||
RETURNS (Pipe): The pipeline component.
|
RETURNS (Callable[[Doc], Doc]): The pipeline component.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/language#add_pipe
|
DOCS: https://spacy.io/api/language#add_pipe
|
||||||
"""
|
"""
|
||||||
|
@ -869,7 +868,7 @@ class Language:
|
||||||
*,
|
*,
|
||||||
config: Dict[str, Any] = SimpleFrozenDict(),
|
config: Dict[str, Any] = SimpleFrozenDict(),
|
||||||
validate: bool = True,
|
validate: bool = True,
|
||||||
) -> "Pipe":
|
) -> PipeCallable:
|
||||||
"""Replace a component in the pipeline.
|
"""Replace a component in the pipeline.
|
||||||
|
|
||||||
name (str): Name of the component to replace.
|
name (str): Name of the component to replace.
|
||||||
|
@ -878,7 +877,7 @@ class Language:
|
||||||
component. Will be merged with default config, if available.
|
component. Will be merged with default config, if available.
|
||||||
validate (bool): Whether to validate the component config against the
|
validate (bool): Whether to validate the component config against the
|
||||||
arguments and types expected by the factory.
|
arguments and types expected by the factory.
|
||||||
RETURNS (Pipe): The new pipeline component.
|
RETURNS (Callable[[Doc], Doc]): The new pipeline component.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/language#replace_pipe
|
DOCS: https://spacy.io/api/language#replace_pipe
|
||||||
"""
|
"""
|
||||||
|
@ -930,11 +929,11 @@ class Language:
|
||||||
init_cfg = self._config["initialize"]["components"].pop(old_name)
|
init_cfg = self._config["initialize"]["components"].pop(old_name)
|
||||||
self._config["initialize"]["components"][new_name] = init_cfg
|
self._config["initialize"]["components"][new_name] = init_cfg
|
||||||
|
|
||||||
def remove_pipe(self, name: str) -> Tuple[str, "Pipe"]:
|
def remove_pipe(self, name: str) -> Tuple[str, PipeCallable]:
|
||||||
"""Remove a component from the pipeline.
|
"""Remove a component from the pipeline.
|
||||||
|
|
||||||
name (str): Name of the component to remove.
|
name (str): Name of the component to remove.
|
||||||
RETURNS (tuple): A `(name, component)` tuple of the removed component.
|
RETURNS (Tuple[str, Callable[[Doc], Doc]]): A `(name, component)` tuple of the removed component.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/language#remove_pipe
|
DOCS: https://spacy.io/api/language#remove_pipe
|
||||||
"""
|
"""
|
||||||
|
@ -1349,15 +1348,15 @@ class Language:
|
||||||
|
|
||||||
def set_error_handler(
|
def set_error_handler(
|
||||||
self,
|
self,
|
||||||
error_handler: Callable[[str, "Pipe", List[Doc], Exception], NoReturn],
|
error_handler: Callable[[str, PipeCallable, List[Doc], Exception], NoReturn],
|
||||||
):
|
):
|
||||||
"""Set an error handler object for all the components in the pipeline that implement
|
"""Set an error handler object for all the components in the pipeline
|
||||||
a set_error_handler function.
|
that implement a set_error_handler function.
|
||||||
|
|
||||||
error_handler (Callable[[str, Pipe, List[Doc], Exception], NoReturn]):
|
error_handler (Callable[[str, Callable[[Doc], Doc], List[Doc], Exception], NoReturn]):
|
||||||
Function that deals with a failing batch of documents. This callable function should take in
|
Function that deals with a failing batch of documents. This callable
|
||||||
the component's name, the component itself, the offending batch of documents, and the exception
|
function should take in the component's name, the component itself,
|
||||||
that was thrown.
|
the offending batch of documents, and the exception that was thrown.
|
||||||
DOCS: https://spacy.io/api/language#set_error_handler
|
DOCS: https://spacy.io/api/language#set_error_handler
|
||||||
"""
|
"""
|
||||||
self.default_error_handler = error_handler
|
self.default_error_handler = error_handler
|
||||||
|
|
|
@ -838,8 +838,8 @@ def test_textcat_loss(multi_label: bool, expected_loss: float):
|
||||||
textcat = nlp.add_pipe("textcat_multilabel")
|
textcat = nlp.add_pipe("textcat_multilabel")
|
||||||
else:
|
else:
|
||||||
textcat = nlp.add_pipe("textcat")
|
textcat = nlp.add_pipe("textcat")
|
||||||
textcat.initialize(lambda: train_examples)
|
|
||||||
assert isinstance(textcat, TextCategorizer)
|
assert isinstance(textcat, TextCategorizer)
|
||||||
|
textcat.initialize(lambda: train_examples)
|
||||||
scores = textcat.model.ops.asarray(
|
scores = textcat.model.ops.asarray(
|
||||||
[[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]], dtype="f" # type: ignore
|
[[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]], dtype="f" # type: ignore
|
||||||
)
|
)
|
||||||
|
|
|
@ -51,8 +51,7 @@ from . import about
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
# This lets us add type hints for mypy etc. without causing circular imports
|
# This lets us add type hints for mypy etc. without causing circular imports
|
||||||
from .language import Language # noqa: F401
|
from .language import Language, PipeCallable # noqa: F401
|
||||||
from .pipeline import Pipe # noqa: F401
|
|
||||||
from .tokens import Doc, Span # noqa: F401
|
from .tokens import Doc, Span # noqa: F401
|
||||||
from .vocab import Vocab # noqa: F401
|
from .vocab import Vocab # noqa: F401
|
||||||
|
|
||||||
|
@ -1642,9 +1641,9 @@ def check_bool_env_var(env_var: str) -> bool:
|
||||||
|
|
||||||
def _pipe(
|
def _pipe(
|
||||||
docs: Iterable["Doc"],
|
docs: Iterable["Doc"],
|
||||||
proc: "Pipe",
|
proc: "PipeCallable",
|
||||||
name: str,
|
name: str,
|
||||||
default_error_handler: Callable[[str, "Pipe", List["Doc"], Exception], NoReturn],
|
default_error_handler: Callable[[str, "PipeCallable", List["Doc"], Exception], NoReturn],
|
||||||
kwargs: Mapping[str, Any],
|
kwargs: Mapping[str, Any],
|
||||||
) -> Iterator["Doc"]:
|
) -> Iterator["Doc"]:
|
||||||
if hasattr(proc, "pipe"):
|
if hasattr(proc, "pipe"):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user