mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
0ec9a696e6
* Enable Cython<->Python bindings for `Pipe` and `TrainablePipe` methods * `pipes_with_nvtx_range`: Skip hooking methods whose signature cannot be ascertained When loading pipelines from a config file, the arguments passed to individual pipeline components is validated by `pydantic` during init. For this, the validation model attempts to parse the function signature of the component's c'tor/entry point so that it can check if all mandatory parameters are present in the config file. When using the `models_and_pipes_with_nvtx_range` as a `after_pipeline_creation` callback, the methods of all pipeline components get replaced by a NVTX range wrapper **before** the above-mentioned validation takes place. This can be problematic for components that are implemented as Cython extension types - if the extension type is not compiled with Python bindings for its methods, they will have no signatures at runtime. This resulted in `pydantic` matching the *wrapper's* parameters with the those in the config and raising errors. To avoid this, we now skip applying the wrapper to any (Cython) methods that do not have signatures.
150 lines
5.5 KiB
Cython
150 lines
5.5 KiB
Cython
# cython: infer_types=True, profile=True, binding=True
|
|
from typing import Optional, Tuple, Iterable, Iterator, Callable, Union, Dict
|
|
import srsly
|
|
import warnings
|
|
|
|
from ..tokens.doc cimport Doc
|
|
|
|
from ..training import Example
|
|
from ..errors import Errors, Warnings
|
|
from ..language import Language
|
|
from ..util import raise_error
|
|
|
|
cdef class Pipe:
|
|
"""This class is a base class and not instantiated directly. It provides
|
|
an interface for pipeline components to implement.
|
|
Trainable pipeline components like the EntityRecognizer or TextCategorizer
|
|
should inherit from the subclass 'TrainablePipe'.
|
|
|
|
DOCS: https://spacy.io/api/pipe
|
|
"""
|
|
|
|
@classmethod
|
|
def __init_subclass__(cls, **kwargs):
|
|
"""Raise a warning if an inheriting class implements 'begin_training'
|
|
(from v2) instead of the new 'initialize' method (from v3)"""
|
|
if hasattr(cls, "begin_training"):
|
|
warnings.warn(Warnings.W088.format(name=cls.__name__))
|
|
|
|
def __call__(self, Doc doc) -> Doc:
|
|
"""Apply the pipe to one document. The document is modified in place,
|
|
and returned. This usually happens under the hood when the nlp object
|
|
is called on a text and all components are applied to the Doc.
|
|
|
|
doc (Doc): The Doc to process.
|
|
RETURNS (Doc): The processed Doc.
|
|
|
|
DOCS: https://spacy.io/api/pipe#call
|
|
"""
|
|
raise NotImplementedError(Errors.E931.format(parent="Pipe", method="__call__", name=self.name))
|
|
|
|
def pipe(self, stream: Iterable[Doc], *, batch_size: int=128) -> Iterator[Doc]:
|
|
"""Apply the pipe to a stream of documents. This usually happens under
|
|
the hood when the nlp object is called on a text and all components are
|
|
applied to the Doc.
|
|
|
|
stream (Iterable[Doc]): A stream of documents.
|
|
batch_size (int): The number of documents to buffer.
|
|
YIELDS (Doc): Processed documents in order.
|
|
|
|
DOCS: https://spacy.io/api/pipe#pipe
|
|
"""
|
|
error_handler = self.get_error_handler()
|
|
for doc in stream:
|
|
try:
|
|
doc = self(doc)
|
|
yield doc
|
|
except Exception as e:
|
|
error_handler(self.name, self, [doc], e)
|
|
|
|
def initialize(self, get_examples: Callable[[], Iterable[Example]], *, nlp: Language=None):
|
|
"""Initialize the pipe. For non-trainable components, this method
|
|
is optional. For trainable components, which should inherit
|
|
from the subclass TrainablePipe, the provided data examples
|
|
should be used to ensure that the internal model is initialized
|
|
properly and all input/output dimensions throughout the network are
|
|
inferred.
|
|
|
|
get_examples (Callable[[], Iterable[Example]]): Function that
|
|
returns a representative sample of gold-standard Example objects.
|
|
nlp (Language): The current nlp object the component is part of.
|
|
|
|
DOCS: https://spacy.io/api/pipe#initialize
|
|
"""
|
|
pass
|
|
|
|
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Union[float, Dict[str, float]]]:
|
|
"""Score a batch of examples.
|
|
|
|
examples (Iterable[Example]): The examples to score.
|
|
RETURNS (Dict[str, Any]): The scores.
|
|
|
|
DOCS: https://spacy.io/api/pipe#score
|
|
"""
|
|
if hasattr(self, "scorer") and self.scorer is not None:
|
|
scorer_kwargs = {}
|
|
# use default settings from cfg (e.g., threshold)
|
|
if hasattr(self, "cfg") and isinstance(self.cfg, dict):
|
|
scorer_kwargs.update(self.cfg)
|
|
# override self.cfg["labels"] with self.labels
|
|
if hasattr(self, "labels"):
|
|
scorer_kwargs["labels"] = self.labels
|
|
# override with kwargs settings
|
|
scorer_kwargs.update(kwargs)
|
|
return self.scorer(examples, **scorer_kwargs)
|
|
return {}
|
|
|
|
@property
|
|
def is_trainable(self) -> bool:
|
|
return False
|
|
|
|
@property
|
|
def labels(self) -> Tuple[str, ...]:
|
|
return tuple()
|
|
|
|
@property
|
|
def hide_labels(self) -> bool:
|
|
return False
|
|
|
|
@property
|
|
def label_data(self):
|
|
"""Optional JSON-serializable data that would be sufficient to recreate
|
|
the label set if provided to the `pipe.initialize()` method.
|
|
"""
|
|
return None
|
|
|
|
def _require_labels(self) -> None:
|
|
"""Raise an error if this component has no labels defined."""
|
|
if not self.labels or list(self.labels) == [""]:
|
|
raise ValueError(Errors.E143.format(name=self.name))
|
|
|
|
def set_error_handler(self, error_handler: Callable) -> None:
|
|
"""Set an error handler function.
|
|
|
|
error_handler (Callable[[str, Callable[[Doc], Doc], List[Doc], Exception], None]):
|
|
Function that deals with a failing batch of documents. This callable function should take in
|
|
the component's name, the component itself, the offending batch of documents, and the exception
|
|
that was thrown.
|
|
|
|
DOCS: https://spacy.io/api/pipe#set_error_handler
|
|
"""
|
|
self.error_handler = error_handler
|
|
|
|
def get_error_handler(self) -> Callable:
|
|
"""Retrieve the error handler function.
|
|
|
|
RETURNS (Callable): The error handler, or if it's not set a default function that just reraises.
|
|
|
|
DOCS: https://spacy.io/api/pipe#get_error_handler
|
|
"""
|
|
if hasattr(self, "error_handler"):
|
|
return self.error_handler
|
|
return raise_error
|
|
|
|
|
|
def deserialize_config(path):
|
|
if path.exists():
|
|
return srsly.read_json(path)
|
|
else:
|
|
return {}
|