Error handling in nlp.pipe (#6817)

* add error handler for pipe methods

* add unit tests

* remove pipe methods that are the same as their base class

* have Language keep track of a default error handler

* cleanup

* formatting

* small refactor

* add documentation
Sofie Van Landeghem 2021-01-29 01:51:21 +01:00 committed by GitHub
parent cc18f3f23c
commit 837a4f53c2
16 changed files with 323 additions and 200 deletions


@@ -8,7 +8,7 @@ from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path
import warnings
from thinc.api import Model, get_current_ops, Config, Optimizer
from thinc.api import get_current_ops, Config, Optimizer
import srsly
import multiprocessing as mp
from itertools import chain, cycle
@@ -20,7 +20,7 @@ from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
from .training import Example, validate_examples
from .training.initialize import init_vocab, init_tok2vec
from .scorer import Scorer
from .util import registry, SimpleFrozenList, _pipe
from .util import registry, SimpleFrozenList, _pipe, raise_error
from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER
from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
@@ -176,6 +176,7 @@ class Language:
create_tokenizer = registry.resolve(tokenizer_cfg)["tokenizer"]
self.tokenizer = create_tokenizer(self)
self.batch_size = batch_size
self.default_error_handler = raise_error
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
@@ -981,11 +982,16 @@ class Language:
continue
if not hasattr(proc, "__call__"):
raise ValueError(Errors.E003.format(component=type(proc), name=name))
error_handler = self.default_error_handler
if hasattr(proc, "get_error_handler"):
error_handler = proc.get_error_handler()
try:
doc = proc(doc, **component_cfg.get(name, {}))
except KeyError as e:
# This typically happens if a component is not initialized
raise ValueError(Errors.E109.format(name=name)) from e
except Exception as e:
error_handler(name, proc, [doc], e)
if doc is None:
raise ValueError(Errors.E005.format(name=name))
return doc
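
At the single-document level this changes what `nlp(text)` does on failure. A minimal sketch, assuming a hypothetical component `fussy` that always fails: with an ignoring handler the call no longer raises, and returns the `Doc` without that component's annotations.

```python
from spacy.lang.en import English
from spacy.language import Language
from spacy.util import ignore_error

@Language.component("fussy")
def fussy(doc):
    raise ValueError("can't handle this")

nlp = English()
nlp.add_pipe("fussy")
# Default behaviour: the error propagates out of nlp(...).
# nlp("hello")  # would raise ValueError
nlp.set_error_handler(ignore_error)
doc = nlp("hello")  # swallowed; returns the tokenized, unprocessed Doc
assert doc.text == "hello"
```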
@@ -1274,6 +1280,26 @@ class Language:
self._optimizer = self.create_optimizer()
return self._optimizer
def set_error_handler(
self,
error_handler: Callable[
[str, Callable[[Doc], Doc], List[Doc], Exception], None
],
):
"""Set an error handler object for all the components in the pipeline that implement
a set_error_handler function.
error_handler (Callable[[str, Callable[[Doc], Doc], List[Doc], Exception], None]):
Function that deals with a failing batch of documents. This callable should take
the component's name, the component itself, the offending batch of documents, and the exception
that was raised.
DOCS: https://nightly.spacy.io/api/language#set_error_handler
"""
self.default_error_handler = error_handler
for name, pipe in self.pipeline:
if hasattr(pipe, "set_error_handler"):
pipe.set_error_handler(error_handler)
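
A minimal usage sketch of the propagation this method performs (the `log_error` handler is an illustrative assumption, not part of this commit):

```python
import logging
from spacy.lang.en import English

logger = logging.getLogger("spacy")

def log_error(proc_name, proc, docs, e):
    # Receives the component's name, the component itself, the failing
    # batch of Doc objects, and the exception that was raised.
    logger.warning("Component %s failed on %d doc(s): %s", proc_name, len(docs), e)

nlp = English()
ruler = nlp.add_pipe("entity_ruler")
nlp.set_error_handler(log_error)
# Stored as the pipeline-wide default and pushed to every component
# that implements set_error_handler:
assert nlp.default_error_handler is log_error
assert ruler.get_error_handler() is log_error
```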
def evaluate(
self,
examples: Iterable[Example],
@@ -1293,6 +1319,7 @@ class Language:
arguments for specific components.
scorer_cfg (dict): An optional dictionary with extra keyword arguments
for the scorer.
RETURNS (Scorer): The scorer containing the evaluation results.
DOCS: https://nightly.spacy.io/api/language#evaluate
@@ -1317,7 +1344,14 @@
kwargs = component_cfg.get(name, {})
kwargs.setdefault("batch_size", batch_size)
for doc, eg in zip(
_pipe((eg.predicted for eg in examples), pipe, kwargs), examples
_pipe(
(eg.predicted for eg in examples),
proc=pipe,
name=name,
default_error_handler=self.default_error_handler,
kwargs=kwargs,
),
examples,
):
eg.predicted = doc
end_time = timer()
@@ -1422,7 +1456,13 @@
kwargs = component_cfg.get(name, {})
# Allow component_cfg to overwrite the top-level kwargs.
kwargs.setdefault("batch_size", batch_size)
f = functools.partial(_pipe, proc=proc, kwargs=kwargs)
f = functools.partial(
_pipe,
proc=proc,
name=name,
kwargs=kwargs,
default_error_handler=self.default_error_handler,
)
pipes.append(f)
if n_process != 1:


@@ -96,12 +96,25 @@ class AttributeRuler(Pipe):
DOCS: https://nightly.spacy.io/api/attributeruler#call
"""
error_handler = self.get_error_handler()
try:
matches = self.match(doc)
self.set_annotations(doc, matches)
return doc
except Exception as e:
error_handler(self.name, self, [doc], e)
def match(self, doc: Doc):
matches = self.matcher(doc, allow_missing=True)
# Sort by the attribute ID, so that later rules have precedence
matches = [
(int(self.vocab.strings[m_id]), m_id, s, e) for m_id, s, e in matches
]
matches.sort()
return matches
def set_annotations(self, doc, matches):
"""Modify the document in place"""
for attr_id, match_id, start, end in matches:
span = Span(doc, start, end, label=match_id)
attrs = self.attrs[attr_id]
@@ -121,7 +134,7 @@
)
) from None
set_token_attrs(span[index], attrs)
return doc
def load_from_tag_map(
self, tag_map: Dict[str, Dict[Union[int, str], Union[int, str]]]


@@ -1,6 +1,6 @@
from itertools import islice
from typing import Optional, Iterable, Callable, Dict, Iterator, Union, List
from typing import Optional, Iterable, Callable, Dict, Union, List
from pathlib import Path
from itertools import islice
import srsly
import random
from thinc.api import CosineDistance, Model, Optimizer, Config
@@ -276,34 +276,6 @@ class EntityLinker(TrainablePipe):
loss = loss / len(entity_encodings)
return loss, gradients
def __call__(self, doc: Doc) -> Doc:
"""Apply the pipe to a Doc.
doc (Doc): The document to process.
RETURNS (Doc): The processed Doc.
DOCS: https://nightly.spacy.io/api/entitylinker#call
"""
kb_ids = self.predict([doc])
self.set_annotations([doc], kb_ids)
return doc
def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
"""Apply the pipe to a stream of documents. This usually happens under
the hood when the nlp object is called on a text and all components are
applied to the Doc.
stream (Iterable[Doc]): A stream of documents.
batch_size (int): The number of documents to buffer.
YIELDS (Doc): Processed documents in order.
DOCS: https://nightly.spacy.io/api/entitylinker#pipe
"""
for docs in util.minibatch(stream, size=batch_size):
kb_ids = self.predict(docs)
self.set_annotations(docs, kb_ids)
yield from docs
def predict(self, docs: Iterable[Doc]) -> List[str]:
"""Apply the pipeline's model to a batch of docs, without modifying them.
Returns the KB IDs for each entity in each doc, including NIL if there is


@@ -135,12 +135,25 @@ class EntityRuler(Pipe):
DOCS: https://nightly.spacy.io/api/entityruler#call
"""
error_handler = self.get_error_handler()
try:
matches = self.match(doc)
self.set_annotations(doc, matches)
return doc
except Exception as e:
error_handler(self.name, self, [doc], e)
def match(self, doc: Doc):
matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
matches = set(
[(m_id, start, end) for m_id, start, end in matches if start != end]
)
get_sort_key = lambda m: (m[2] - m[1], -m[1])
matches = sorted(matches, key=get_sort_key, reverse=True)
return matches
def set_annotations(self, doc, matches):
"""Modify the document in place"""
entities = list(doc.ents)
new_entities = []
seen_tokens = set()
@@ -163,7 +176,6 @@ class EntityRuler(Pipe):
]
seen_tokens.update(range(start, end))
doc.ents = entities + new_entities
return doc
@property
def labels(self) -> Tuple[str, ...]:


@@ -23,11 +23,7 @@ from .. import util
default_score_weights={"lemma_acc": 1.0},
)
def make_lemmatizer(
nlp: Language,
model: Optional[Model],
name: str,
mode: str,
overwrite: bool = False,
nlp: Language, model: Optional[Model], name: str, mode: str, overwrite: bool = False
):
return Lemmatizer(nlp.vocab, model, name, mode=mode, overwrite=overwrite)
@@ -107,10 +103,14 @@ class Lemmatizer(Pipe):
"""
if not self._validated:
self._validate_tables(Errors.E1004)
for token in doc:
if self.overwrite or token.lemma == 0:
token.lemma_ = self.lemmatize(token)[0]
return doc
error_handler = self.get_error_handler()
try:
for token in doc:
if self.overwrite or token.lemma == 0:
token.lemma_ = self.lemmatize(token)[0]
return doc
except Exception as e:
error_handler(self.name, self, [doc], e)
def initialize(
self,
@@ -154,21 +154,6 @@
)
self._validated = True
def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
"""Apply the pipe to a stream of documents. This usually happens under
the hood when the nlp object is called on a text and all components are
applied to the Doc.
stream (Iterable[Doc]): A stream of documents.
batch_size (int): The number of documents to buffer.
YIELDS (Doc): Processed documents in order.
DOCS: https://nightly.spacy.io/api/lemmatizer#pipe
"""
for doc in stream:
doc = self(doc)
yield doc
def lookup_lemmatize(self, token: Token) -> List[str]:
"""Lemmatize using a lookup-based approach.


@@ -1,13 +1,14 @@
# cython: infer_types=True, profile=True
import warnings
from typing import Optional, Tuple, Iterable, Iterator, Callable, Union, Dict
import srsly
import warnings
from ..tokens.doc cimport Doc
from ..training import Example
from ..errors import Errors, Warnings
from ..language import Language
from ..util import raise_error
cdef class Pipe:
"""This class is a base class and not instantiated directly. It provides
@@ -48,9 +49,13 @@ cdef class Pipe:
DOCS: https://nightly.spacy.io/api/pipe#pipe
"""
error_handler = self.get_error_handler()
for doc in stream:
doc = self(doc)
yield doc
try:
doc = self(doc)
yield doc
except Exception as e:
error_handler(self.name, self, [doc], e)
def initialize(self, get_examples: Callable[[], Iterable[Example]], *, nlp: Language=None):
"""Initialize the pipe. For non-trainable components, this method
@@ -98,6 +103,30 @@ cdef class Pipe:
if not self.labels or list(self.labels) == [""]:
raise ValueError(Errors.E143.format(name=self.name))
def set_error_handler(self, error_handler: Callable) -> None:
"""Set an error handler function.
error_handler (Callable[[str, Callable[[Doc], Doc], List[Doc], Exception], None]):
Function that deals with a failing batch of documents. This callable should take
the component's name, the component itself, the offending batch of documents, and the exception
that was raised.
DOCS: https://nightly.spacy.io/api/pipe#set_error_handler
"""
self.error_handler = error_handler
def get_error_handler(self) -> Optional[Callable]:
"""Retrieve the error handler function.
RETURNS (Callable): The error handler or, if none is set, a default function that simply reraises the exception.
DOCS: https://nightly.spacy.io/api/pipe#get_error_handler
"""
if hasattr(self, "error_handler"):
return self.error_handler
return raise_error
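
A handler set directly on a component takes precedence over the pipeline-wide default, since `get_error_handler` only falls back to `raise_error` when nothing was set. A short sketch of the interplay (the choice of components is arbitrary):

```python
from spacy.lang.en import English
from spacy.util import ignore_error, raise_error

nlp = English()
ruler = nlp.add_pipe("entity_ruler")
senter = nlp.add_pipe("sentencizer")

nlp.set_error_handler(ignore_error)    # propagated to both components
senter.set_error_handler(raise_error)  # per-component override afterwards

assert ruler.get_error_handler() is ignore_error
assert senter.get_error_handler() is raise_error
```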
def deserialize_config(path):
if path.exists():
return srsly.read_json(path)


@@ -1,16 +1,14 @@
# cython: infer_types=True, profile=True, binding=True
import srsly
from typing import Optional, List
import srsly
from ..tokens.doc cimport Doc
from .pipe import Pipe
from ..language import Language
from ..scorer import Scorer
from ..training import validate_examples
from .. import util
@Language.factory(
"sentencizer",
assigns=["token.is_sent_start", "doc.sents"],
@@ -66,6 +64,14 @@ class Sentencizer(Pipe):
DOCS: https://nightly.spacy.io/api/sentencizer#call
"""
error_handler = self.get_error_handler()
try:
self._call(doc)
return doc
except Exception as e:
error_handler(self.name, self, [doc], e)
def _call(self, doc):
start = 0
seen_period = False
for i, token in enumerate(doc):
@@ -79,23 +85,6 @@
seen_period = True
if start < len(doc):
doc[start].is_sent_start = True
return doc
def pipe(self, stream, batch_size=128):
"""Apply the pipe to a stream of documents. This usually happens under
the hood when the nlp object is called on a text and all components are
applied to the Doc.
stream (Iterable[Doc]): A stream of documents.
batch_size (int): The number of documents to buffer.
YIELDS (Doc): Processed documents in order.
DOCS: https://nightly.spacy.io/api/sentencizer#pipe
"""
for docs in util.minibatch(stream, size=batch_size):
predictions = self.predict(docs)
self.set_annotations(docs, predictions)
yield from docs
def predict(self, docs):
"""Apply the pipe to a batch of docs, without modifying them.


@@ -1,5 +1,4 @@
# cython: infer_types=True, profile=True, binding=True
from typing import List
import numpy
import srsly
from thinc.api import Model, set_dropout_rate, SequenceCategoricalCrossentropy, Config
@@ -95,34 +94,6 @@ class Tagger(TrainablePipe):
"""Data about the labels currently added to the component."""
return tuple(self.cfg["labels"])
def __call__(self, doc):
"""Apply the pipe to a Doc.
doc (Doc): The document to process.
RETURNS (Doc): The processed Doc.
DOCS: https://nightly.spacy.io/api/tagger#call
"""
tags = self.predict([doc])
self.set_annotations([doc], tags)
return doc
def pipe(self, stream, *, batch_size=128):
"""Apply the pipe to a stream of documents. This usually happens under
the hood when the nlp object is called on a text and all components are
applied to the Doc.
stream (Iterable[Doc]): A stream of documents.
batch_size (int): The number of documents to buffer.
YIELDS (Doc): Processed documents in order.
DOCS: https://nightly.spacy.io/api/tagger#pipe
"""
for docs in util.minibatch(stream, size=batch_size):
tag_ids = self.predict(docs)
self.set_annotations(docs, tag_ids)
yield from docs
def predict(self, docs):
"""Apply the pipeline's model to a batch of docs, without modifying them.


@@ -1,5 +1,5 @@
from itertools import islice
from typing import Iterable, Tuple, Optional, Dict, List, Callable, Iterator, Any
from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any
from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
from thinc.types import Floats2d
import numpy
@@ -9,7 +9,6 @@ from ..language import Language
from ..training import Example, validate_examples, validate_get_examples
from ..errors import Errors
from ..scorer import Scorer
from .. import util
from ..tokens import Doc
from ..vocab import Vocab
@@ -144,22 +143,6 @@ class TextCategorizer(TrainablePipe):
"""
return self.labels
def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
"""Apply the pipe to a stream of documents. This usually happens under
the hood when the nlp object is called on a text and all components are
applied to the Doc.
stream (Iterable[Doc]): A stream of documents.
batch_size (int): The number of documents to buffer.
YIELDS (Doc): Processed documents in order.
DOCS: https://nightly.spacy.io/api/textcategorizer#pipe
"""
for docs in util.minibatch(stream, size=batch_size):
scores = self.predict(docs)
self.set_annotations(docs, scores)
yield from docs
def predict(self, docs: Iterable[Doc]):
"""Apply the pipeline's model to a batch of docs, without modifying them.


@@ -1,4 +1,4 @@
from typing import Iterator, Sequence, Iterable, Optional, Dict, Callable, List
from typing import Sequence, Iterable, Optional, Dict, Callable, List
from thinc.api import Model, set_dropout_rate, Optimizer, Config
from itertools import islice
@@ -8,8 +8,6 @@ from ..tokens import Doc
from ..vocab import Vocab
from ..language import Language
from ..errors import Errors
from ..util import minibatch
default_model_config = """
[model]
@@ -99,36 +97,6 @@ class Tok2Vec(TrainablePipe):
if isinstance(node, Tok2VecListener) and node.upstream_name in names:
self.add_listener(node, component.name)
def __call__(self, doc: Doc) -> Doc:
"""Add context-sensitive embeddings to the Doc.tensor attribute, allowing
them to be used as features by downstream components.
docs (Doc): The Doc to process.
RETURNS (Doc): The processed Doc.
DOCS: https://nightly.spacy.io/api/tok2vec#call
"""
tokvecses = self.predict([doc])
self.set_annotations([doc], tokvecses)
return doc
def pipe(self, stream: Iterator[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
"""Apply the pipe to a stream of documents. This usually happens under
the hood when the nlp object is called on a text and all components are
applied to the Doc.
stream (Iterable[Doc]): A stream of documents.
batch_size (int): The number of documents to buffer.
YIELDS (Doc): Processed documents in order.
DOCS: https://nightly.spacy.io/api/tok2vec#pipe
"""
for docs in minibatch(stream, batch_size):
docs = list(docs)
tokvecses = self.predict(docs)
self.set_annotations(docs, tokvecses)
yield from docs
def predict(self, docs: Iterable[Doc]):
"""Apply the pipeline's model to a batch of docs, without modifying them.
Returns a single tensor for a batch of documents.


@@ -28,7 +28,7 @@ cdef class TrainablePipe(Pipe):
vocab (Vocab): The shared vocabulary.
model (thinc.api.Model): The Thinc Model powering the pipeline component.
name (str): The component instance name.
**cfg: Additonal settings and config parameters.
**cfg: Additional settings and config parameters.
DOCS: https://nightly.spacy.io/api/pipe#init
"""
@@ -47,9 +47,13 @@ cdef class TrainablePipe(Pipe):
DOCS: https://nightly.spacy.io/api/pipe#call
"""
scores = self.predict([doc])
self.set_annotations([doc], scores)
return doc
error_handler = self.get_error_handler()
try:
scores = self.predict([doc])
self.set_annotations([doc], scores)
return doc
except Exception as e:
error_handler(self.name, self, [doc], e)
def pipe(self, stream: Iterable[Doc], *, batch_size: int=128) -> Iterator[Doc]:
"""Apply the pipe to a stream of documents. This usually happens under
@@ -58,14 +62,21 @@ cdef class TrainablePipe(Pipe):
stream (Iterable[Doc]): A stream of documents.
batch_size (int): The number of documents to buffer.
error_handler (Callable[[str, Callable[[Doc], Doc], List[Doc], Exception], Any]):
Function that deals with a failing batch of documents. The default function just
reraises the exception.
YIELDS (Doc): Processed documents in order.
DOCS: https://nightly.spacy.io/api/pipe#pipe
"""
error_handler = self.get_error_handler()
for docs in util.minibatch(stream, size=batch_size):
scores = self.predict(docs)
self.set_annotations(docs, scores)
yield from docs
try:
scores = self.predict(docs)
self.set_annotations(docs, scores)
yield from docs
except Exception as e:
error_handler(self.name, self, docs, e)
def predict(self, docs: Iterable[Doc]):
"""Apply the pipeline's model to a batch of docs, without modifying them.


@@ -7,7 +7,6 @@ from libcpp.vector cimport vector
from libc.string cimport memset, memcpy
from libc.stdlib cimport calloc, free
import random
from typing import Optional
import srsly
from thinc.api import set_dropout_rate, CupyOps
@@ -30,7 +29,6 @@ from ..training import validate_examples, validate_get_examples
from ..errors import Errors, Warnings
from .. import util
cdef class Parser(TrainablePipe):
"""
Base class of the DependencyParser and EntityRecognizer.
@@ -175,32 +173,31 @@ cdef class Parser(TrainablePipe):
with self.model.use_params(params):
yield
def __call__(self, Doc doc):
"""Apply the parser or entity recognizer, setting the annotations onto
the `Doc` object.
doc (Doc): The document to be processed.
"""
states = self.predict([doc])
self.set_annotations([doc], states)
return doc
def pipe(self, docs, *, int batch_size=256):
"""Process a stream of documents.
stream: The sequence of documents to process.
batch_size (int): Number of documents to accumulate into a working set.
error_handler (Callable[[str, Callable[[Doc], Doc], List[Doc], Exception], Any]):
Function that deals with a failing batch of documents. The default function just
reraises the exception.
YIELDS (Doc): Documents, in order.
"""
cdef Doc doc
error_handler = self.get_error_handler()
for batch in util.minibatch(docs, size=batch_size):
batch_in_order = list(batch)
by_length = sorted(batch, key=lambda doc: len(doc))
for subbatch in util.minibatch(by_length, size=max(batch_size//4, 2)):
subbatch = list(subbatch)
parse_states = self.predict(subbatch)
self.set_annotations(subbatch, parse_states)
yield from batch_in_order
try:
by_length = sorted(batch, key=lambda doc: len(doc))
for subbatch in util.minibatch(by_length, size=max(batch_size//4, 2)):
subbatch = list(subbatch)
parse_states = self.predict(subbatch)
self.set_annotations(subbatch, parse_states)
yield from batch_in_order
except Exception as e:
error_handler(self.name, self, batch_in_order, e)
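
The length-based sub-batching around the new `try` block predates this commit, but the pattern is worth spelling out: each batch is sorted by document length and annotated in smaller sub-batches (cheaper padding), while output order is preserved. A minimal sketch with a hypothetical in-place `annotate` callback:

```python
from spacy.util import minibatch

def sorted_subbatch_pipe(items, annotate, *, batch_size=256):
    # Mirrors Parser.pipe: annotate length-sorted sub-batches in place,
    # then yield the batch in its original order.
    for batch in minibatch(items, size=batch_size):
        batch_in_order = list(batch)
        by_length = sorted(batch_in_order, key=len)
        for subbatch in minibatch(by_length, size=max(batch_size // 4, 2)):
            annotate(list(subbatch))  # mutates items; sub-batch order is irrelevant
        yield from batch_in_order
```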
def predict(self, docs):
if isinstance(docs, Doc):


@@ -1,4 +1,6 @@
import itertools
import logging
from unittest import mock
import pytest
from spacy.language import Language
from spacy.tokens import Doc, Span
@@ -6,7 +8,7 @@ from spacy.vocab import Vocab
from spacy.training import Example
from spacy.lang.en import English
from spacy.lang.de import German
from spacy.util import registry
from spacy.util import registry, ignore_error, raise_error
import spacy
from .util import add_vecs_to_vocab, assert_docs_equal
@@ -161,6 +163,81 @@ def test_language_pipe_stream(nlp2, n_process, texts):
assert_docs_equal(doc, expected_doc)
def test_language_pipe_error_handler():
"""Test that the error handling of nlp.pipe works well"""
nlp = English()
nlp.add_pipe("merge_subtokens")
nlp.initialize()
texts = ["Curious to see what will happen to this text.", "And this one."]
# the pipeline fails because there's no parser
with pytest.raises(ValueError):
nlp(texts[0])
with pytest.raises(ValueError):
list(nlp.pipe(texts))
nlp.set_error_handler(raise_error)
with pytest.raises(ValueError):
list(nlp.pipe(texts))
# set explicitly to ignoring
nlp.set_error_handler(ignore_error)
docs = list(nlp.pipe(texts))
assert len(docs) == 0
nlp(texts[0])
def test_language_pipe_error_handler_custom(en_vocab):
"""Test the error handling of a custom component that has no pipe method"""
@Language.component("my_evil_component")
def evil_component(doc):
if "2" in doc.text:
raise ValueError("no dice")
return doc
def warn_error(proc_name, proc, docs, e):
from spacy.util import logger
logger.warning(f"Trouble with component {proc_name}.")
nlp = English()
nlp.add_pipe("my_evil_component")
nlp.initialize()
texts = ["TEXT 111", "TEXT 222", "TEXT 333", "TEXT 342", "TEXT 666"]
with pytest.raises(ValueError):
# the evil custom component throws an error
list(nlp.pipe(texts))
nlp.set_error_handler(warn_error)
logger = logging.getLogger("spacy")
with mock.patch.object(logger, "warning") as mock_warning:
# the errors by the evil custom component raise a warning for each bad batch
docs = list(nlp.pipe(texts))
mock_warning.assert_called()
assert mock_warning.call_count == 2
assert len(docs) + mock_warning.call_count == len(texts)
assert [doc.text for doc in docs] == ["TEXT 111", "TEXT 333", "TEXT 666"]
def test_language_pipe_error_handler_pipe(en_vocab):
"""Test the error handling of a component's pipe method"""
@Language.component("my_sentences")
def perhaps_set_sentences(doc):
if not doc.text.startswith("4"):
doc[-1].is_sent_start = True
return doc
texts = [f"{str(i)} is enough. Done" for i in range(100)]
nlp = English()
nlp.add_pipe("my_sentences")
entity_linker = nlp.add_pipe("entity_linker", config={"entity_vector_length": 3})
entity_linker.kb.add_entity(entity="Q1", freq=12, entity_vector=[1, 2, 3])
nlp.initialize()
with pytest.raises(ValueError):
# the entity linker requires sentence boundaries, will throw an error otherwise
docs = list(nlp.pipe(texts, batch_size=10))
nlp.set_error_handler(ignore_error)
docs = list(nlp.pipe(texts, batch_size=10))
# we lose/ignore the failing 0-9 and 40-49 batches
assert len(docs) == 80
def test_language_from_config_before_after_init():
name = "test_language_from_config_before_after_init"
ran_before = False


@@ -1420,15 +1420,28 @@ def check_bool_env_var(env_var: str) -> bool:
return bool(value)
def _pipe(docs, proc, kwargs):
def _pipe(docs, proc, name, default_error_handler, kwargs):
if hasattr(proc, "pipe"):
yield from proc.pipe(docs, **kwargs)
else:
# We added some args for pipe that __call__ doesn't expect.
kwargs = dict(kwargs)
error_handler = default_error_handler
if hasattr(proc, "get_error_handler"):
error_handler = proc.get_error_handler()
for arg in ["batch_size"]:
if arg in kwargs:
kwargs.pop(arg)
for doc in docs:
doc = proc(doc, **kwargs)
yield doc
try:
doc = proc(doc, **kwargs)
yield doc
except Exception as e:
error_handler(name, proc, [doc], e)
def raise_error(proc_name, proc, docs, e):
raise e
def ignore_error(proc_name, proc, docs, e):
pass
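
These two are the built-in extremes: reraise everything (the default) or silently drop failures. For a plain function component with no `pipe` method, `_pipe` falls back to the per-document loop above, so ignoring errors only costs the offending `Doc`. A sketch mirroring the tests in this commit (the `evil_on_2` component is illustrative):

```python
from spacy.lang.en import English
from spacy.language import Language
from spacy.util import ignore_error

@Language.component("evil_on_2")
def evil_on_2(doc):
    if "2" in doc.text:
        raise ValueError("no dice")
    return doc

nlp = English()
nlp.add_pipe("evil_on_2")
nlp.set_error_handler(ignore_error)
docs = list(nlp.pipe(["TEXT 1", "TEXT 2", "TEXT 3"]))
# Only the offending document is dropped, because the fallback wraps
# each __call__ individually.
assert [d.text for d in docs] == ["TEXT 1", "TEXT 3"]
```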


@@ -203,6 +203,28 @@ more efficient than processing texts one-by-one.
| `n_process` <Tag variant="new">2.2.2</Tag> | Number of processors to use. Defaults to `1`. ~~int~~ |
| **YIELDS** | Documents in the order of the original text. ~~Doc~~ |
## Language.set_error_handler {#set_error_handler tag="method"}
Define a callback that will be invoked when an error is thrown during processing
of one or more documents. Specifically, this function will call
[`set_error_handler`](/api/pipe#set_error_handler) on all the pipeline
components that define that function. The error handler will be invoked with the
original component's name, the component itself, the list of documents that was
being processed, and the original error.
> #### Example
>
> ```python
> def warn_error(proc_name, proc, docs, e):
> print(f"An error occurred when applying component {proc_name}.")
>
> nlp.set_error_handler(warn_error)
> ```
| Name | Description |
| --------------- | -------------------------------------------------------------------------------------------------------------- |
| `error_handler` | A function that performs custom error handling. ~~Callable[[str, Callable[[Doc], Doc], List[Doc], Exception], None]~~ |
## Language.initialize {#initialize tag="method" new="3"}
Initialize the pipeline for training and return an


@@ -100,6 +100,47 @@ applied to the `Doc` in order. Both [`__call__`](/api/pipe#call) and
| `batch_size` | The number of documents to buffer. Defaults to `128`. ~~int~~ |
| **YIELDS** | The processed documents in order. ~~Doc~~ |
## TrainablePipe.set_error_handler {#set_error_handler tag="method"}
Define a callback that will be invoked when an error is thrown during processing
of one or more documents with either [`__call__`](/api/pipe#call) or
[`pipe`](/api/pipe#pipe). The error handler will be invoked with the original
component's name, the component itself, the list of documents that was being
processed, and the original error.
> #### Example
>
> ```python
> def warn_error(proc_name, proc, docs, e):
> print(f"An error occurred when applying component {proc_name}.")
>
> pipe = nlp.add_pipe("ner")
> pipe.set_error_handler(warn_error)
> ```
| Name | Description |
| --------------- | -------------------------------------------------------------------------------------------------------------- |
| `error_handler` | A function that performs custom error handling. ~~Callable[[str, Callable[[Doc], Doc], List[Doc], Exception], None]~~ |
## TrainablePipe.get_error_handler {#get_error_handler tag="method"}
Retrieve the callback that performs error handling for this component's
[`__call__`](/api/pipe#call) and [`pipe`](/api/pipe#pipe) methods. If no custom
function was previously defined with
[`set_error_handler`](/api/pipe#set_error_handler), a default function is
returned that simply reraises the exception.
> #### Example
>
> ```python
> pipe = nlp.add_pipe("ner")
> error_handler = pipe.get_error_handler()
> ```
| Name | Description |
| ----------- | ---------------------------------------------------------------------------------------------------------------- |
| **RETURNS** | The function that performs custom error handling. ~~Callable[[str, Callable[[Doc], Doc], List[Doc], Exception], None]~~ |
## TrainablePipe.initialize {#initialize tag="method" new="3"}
Initialize the component for training. `get_examples` should be a function that
@@ -190,14 +231,14 @@ predictions and gold-standard annotations, and update the component's model.
> losses = pipe.update(examples, sgd=optimizer)
> ```
| Name | Description |
| ----------------- | ---------------------------------------------------------------------------------------------------------------------------------- |
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
| _keyword-only_ | |
| `drop` | The dropout rate. ~~float~~ |
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
| Name | Description |
| -------------- | ------------------------------------------------------------------------------------------------------------------------ |
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
| _keyword-only_ | |
| `drop` | The dropout rate. ~~float~~ |
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
## TrainablePipe.rehearse {#rehearse tag="method,experimental" new="3"}