Tidy up pipes (#5906)

* Tidy up pipes

* Fix init, defaults and raise custom errors

* Update docs

* Update docs [ci skip]

* Apply suggestions from code review

Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>

* Tidy up error handling and validation, fix consistency

* Simplify get_examples check

* Remove unused import [ci skip]

Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
This commit is contained in:
Ines Montani 2020-08-11 23:29:31 +02:00 committed by GitHub
parent b7ec06e331
commit 950832f087
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 354 additions and 209 deletions

View File

@ -295,7 +295,11 @@ def train_while_improving(
nlp.rehearse(raw_batch, sgd=optimizer, losses=losses, exclude=exclude) nlp.rehearse(raw_batch, sgd=optimizer, losses=losses, exclude=exclude)
# TODO: refactor this so we don't have to run it separately in here # TODO: refactor this so we don't have to run it separately in here
for name, proc in nlp.pipeline: for name, proc in nlp.pipeline:
if name not in exclude and hasattr(proc, "model"): if (
name not in exclude
and hasattr(proc, "model")
and proc.model not in (True, False, None)
):
proc.model.finish_update(optimizer) proc.model.finish_update(optimizer)
optimizer.step_schedules() optimizer.step_schedules()
if not (step % eval_frequency): if not (step % eval_frequency):

View File

@ -482,6 +482,15 @@ class Errors:
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].") E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
# TODO: fix numbering after merging develop into master # TODO: fix numbering after merging develop into master
E930 = ("Received invalid get_examples callback in {name}.begin_training. "
"Expected function that returns an iterable of Example objects but "
"got: {obj}")
E931 = ("Encountered Pipe subclass without Pipe.{method} method in component "
"'{name}'. If the component is trainable and you want to use this "
"method, make sure it's overwritten on the subclass. If your "
"component isn't trainable, add a method that does nothing or "
"don't use the Pipe base class.")
E940 = ("Found NaN values in scores.")
E941 = ("Can't find model '{name}'. It looks like you're trying to load a " E941 = ("Can't find model '{name}'. It looks like you're trying to load a "
"model from a shortcut, which is deprecated as of spaCy v3.0. To " "model from a shortcut, which is deprecated as of spaCy v3.0. To "
"load the model, use its full name instead:\n\n" "load the model, use its full name instead:\n\n"
@ -578,8 +587,7 @@ class Errors:
"but received None.") "but received None.")
E977 = ("Can not compare a MorphAnalysis with a string object. " E977 = ("Can not compare a MorphAnalysis with a string object. "
"This is likely a bug in spaCy, so feel free to open an issue.") "This is likely a bug in spaCy, so feel free to open an issue.")
E978 = ("The '{method}' method of {name} takes a list of Example objects, " E978 = ("The {name} method takes a list of Example objects, but got: {types}")
"but found {types} instead.")
E979 = ("Cannot convert {type} to an Example object.") E979 = ("Cannot convert {type} to an Example object.")
E980 = ("Each link annotation should refer to a dictionary with at most one " E980 = ("Each link annotation should refer to a dictionary with at most one "
"identifier mapping to 1.0, and all others to 0.0.") "identifier mapping to 1.0, and all others to 0.0.")

View File

@ -1,5 +1,5 @@
from .corpus import Corpus # noqa: F401 from .corpus import Corpus # noqa: F401
from .example import Example # noqa: F401 from .example import Example, validate_examples # noqa: F401
from .align import Alignment # noqa: F401 from .align import Alignment # noqa: F401
from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401 from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401
from .iob_utils import biluo_tags_from_offsets, offsets_from_biluo_tags # noqa: F401 from .iob_utils import biluo_tags_from_offsets, offsets_from_biluo_tags # noqa: F401

View File

@ -1,5 +1,5 @@
from collections import Iterable as IterableInstance
import warnings import warnings
import numpy import numpy
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
@ -26,6 +26,22 @@ cpdef Doc annotations2doc(vocab, tok_annot, doc_annot):
return output return output
def validate_examples(examples, method):
"""Check that a batch of examples received during processing is valid.
This function lives here to prevent circular imports.
examples (Iterable[Examples]): A batch of examples.
method (str): The method name to show in error messages.
"""
if not isinstance(examples, IterableInstance):
err = Errors.E978.format(name=method, types=type(examples))
raise TypeError(err)
wrong = set([type(eg) for eg in examples if not isinstance(eg, Example)])
if wrong:
err = Errors.E978.format(name=method, types=wrong)
raise TypeError(err)
cdef class Example: cdef class Example:
def __init__(self, Doc predicted, Doc reference, *, alignment=None): def __init__(self, Doc predicted, Doc reference, *, alignment=None):
if predicted is None: if predicted is None:
@ -263,12 +279,10 @@ def _annot2array(vocab, tok_annot, doc_annot):
values.append([vocab.morphology.add(v) for v in value]) values.append([vocab.morphology.add(v) for v in value])
else: else:
attrs.append(key) attrs.append(key)
try: if not all(isinstance(v, str) for v in value):
values.append([vocab.strings.add(v) for v in value]) types = set([type(v) for v in value])
except TypeError:
types= set([type(v) for v in value])
raise TypeError(Errors.E969.format(field=key, types=types)) from None raise TypeError(Errors.E969.format(field=key, types=types)) from None
values.append([vocab.strings.add(v) for v in value])
array = numpy.asarray(values, dtype="uint64") array = numpy.asarray(values, dtype="uint64")
return attrs, array.T return attrs, array.T

View File

@ -5,7 +5,6 @@ import random
import itertools import itertools
import weakref import weakref
import functools import functools
from collections import Iterable as IterableInstance
from contextlib import contextmanager from contextlib import contextmanager
from copy import copy, deepcopy from copy import copy, deepcopy
from pathlib import Path from pathlib import Path
@ -19,7 +18,7 @@ from timeit import default_timer as timer
from .tokens.underscore import Underscore from .tokens.underscore import Underscore
from .vocab import Vocab, create_vocab from .vocab import Vocab, create_vocab
from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
from .gold import Example from .gold import Example, validate_examples
from .scorer import Scorer from .scorer import Scorer
from .util import create_default_optimizer, registry from .util import create_default_optimizer, registry
from .util import SimpleFrozenDict, combine_score_weights from .util import SimpleFrozenDict, combine_score_weights
@ -935,17 +934,7 @@ class Language:
losses = {} losses = {}
if len(examples) == 0: if len(examples) == 0:
return losses return losses
if not isinstance(examples, IterableInstance): validate_examples(examples, "Language.update")
raise TypeError(
Errors.E978.format(
name="language", method="update", types=type(examples)
)
)
wrong_types = set([type(eg) for eg in examples if not isinstance(eg, Example)])
if wrong_types:
raise TypeError(
Errors.E978.format(name="language", method="update", types=wrong_types)
)
if sgd is None: if sgd is None:
if self._optimizer is None: if self._optimizer is None:
self._optimizer = create_default_optimizer() self._optimizer = create_default_optimizer()
@ -962,7 +951,11 @@ class Language:
proc.update(examples, sgd=None, losses=losses, **component_cfg[name]) proc.update(examples, sgd=None, losses=losses, **component_cfg[name])
if sgd not in (None, False): if sgd not in (None, False):
for name, proc in self.pipeline: for name, proc in self.pipeline:
if name not in exclude and hasattr(proc, "model"): if (
name not in exclude
and hasattr(proc, "model")
and proc.model not in (True, False, None)
):
proc.model.finish_update(sgd) proc.model.finish_update(sgd)
return losses return losses
@ -999,19 +992,7 @@ class Language:
""" """
if len(examples) == 0: if len(examples) == 0:
return return
if not isinstance(examples, IterableInstance): validate_examples(examples, "Language.rehearse")
raise TypeError(
Errors.E978.format(
name="language", method="rehearse", types=type(examples)
)
)
wrong_types = set([type(eg) for eg in examples if not isinstance(eg, Example)])
if wrong_types:
raise TypeError(
Errors.E978.format(
name="language", method="rehearse", types=wrong_types
)
)
if sgd is None: if sgd is None:
if self._optimizer is None: if self._optimizer is None:
self._optimizer = create_default_optimizer() self._optimizer = create_default_optimizer()
@ -1060,7 +1041,15 @@ class Language:
if get_examples is None: if get_examples is None:
get_examples = lambda: [] get_examples = lambda: []
else: # Populate vocab else: # Populate vocab
if not hasattr(get_examples, "__call__"):
err = Errors.E930.format(name="Language", obj=type(get_examples))
raise ValueError(err)
for example in get_examples(): for example in get_examples():
if not isinstance(example, Example):
err = Errors.E978.format(
name="Language.begin_training", types=type(example)
)
raise ValueError(err)
for word in [t.text for t in example.reference]: for word in [t.text for t in example.reference]:
_ = self.vocab[word] # noqa: F841 _ = self.vocab[word] # noqa: F841
if device >= 0: # TODO: do we need this here? if device >= 0: # TODO: do we need this here?
@ -1133,17 +1122,7 @@ class Language:
DOCS: https://spacy.io/api/language#evaluate DOCS: https://spacy.io/api/language#evaluate
""" """
if not isinstance(examples, IterableInstance): validate_examples(examples, "Language.evaluate")
err = Errors.E978.format(
name="language", method="evaluate", types=type(examples)
)
raise TypeError(err)
wrong_types = set([type(eg) for eg in examples if not isinstance(eg, Example)])
if wrong_types:
err = Errors.E978.format(
name="language", method="evaluate", types=wrong_types
)
raise TypeError(err)
if component_cfg is None: if component_cfg is None:
component_cfg = {} component_cfg = {}
if scorer_cfg is None: if scorer_cfg is None:
@ -1663,7 +1642,7 @@ def _fix_pretrained_vectors_name(nlp: Language) -> None:
else: else:
raise ValueError(Errors.E092) raise ValueError(Errors.E092)
for name, proc in nlp.pipeline: for name, proc in nlp.pipeline:
if not hasattr(proc, "cfg"): if not hasattr(proc, "cfg") or not isinstance(proc.cfg, dict):
continue continue
proc.cfg.setdefault("deprecation_fixes", {}) proc.cfg.setdefault("deprecation_fixes", {})
proc.cfg["deprecation_fixes"]["vectors_name"] = nlp.vocab.vectors.name proc.cfg["deprecation_fixes"]["vectors_name"] = nlp.vocab.vectors.name

View File

@ -9,6 +9,7 @@ from .functions import merge_subtokens
from ..language import Language from ..language import Language
from ._parser_internals import nonproj from ._parser_internals import nonproj
from ..scorer import Scorer from ..scorer import Scorer
from ..gold import validate_examples
default_model_config = """ default_model_config = """
@ -147,6 +148,7 @@ cdef class DependencyParser(Parser):
DOCS: https://spacy.io/api/dependencyparser#score DOCS: https://spacy.io/api/dependencyparser#score
""" """
validate_examples(examples, "DependencyParser.score")
def dep_getter(token, attr): def dep_getter(token, attr):
dep = getattr(token, attr) dep = getattr(token, attr)
dep = token.vocab.strings.as_string(dep).lower() dep = token.vocab.strings.as_string(dep).lower()

View File

@ -11,7 +11,7 @@ from ..tokens import Doc
from .pipe import Pipe, deserialize_config from .pipe import Pipe, deserialize_config
from ..language import Language from ..language import Language
from ..vocab import Vocab from ..vocab import Vocab
from ..gold import Example from ..gold import Example, validate_examples
from ..errors import Errors, Warnings from ..errors import Errors, Warnings
from .. import util from .. import util
@ -142,7 +142,7 @@ class EntityLinker(Pipe):
def begin_training( def begin_training(
self, self,
get_examples: Callable[[], Iterable[Example]] = lambda: [], get_examples: Callable[[], Iterable[Example]],
*, *,
pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None, pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None,
sgd: Optional[Optimizer] = None, sgd: Optional[Optimizer] = None,
@ -197,14 +197,9 @@ class EntityLinker(Pipe):
losses.setdefault(self.name, 0.0) losses.setdefault(self.name, 0.0)
if not examples: if not examples:
return losses return losses
validate_examples(examples, "EntityLinker.update")
sentence_docs = [] sentence_docs = []
try:
docs = [eg.predicted for eg in examples] docs = [eg.predicted for eg in examples]
except AttributeError:
types = set([type(eg) for eg in examples])
raise TypeError(
Errors.E978.format(name="EntityLinker", method="update", types=types)
) from None
if set_annotations: if set_annotations:
# This seems simpler than other ways to get that exact output -- but # This seems simpler than other ways to get that exact output -- but
# it does run the model twice :( # it does run the model twice :(
@ -250,6 +245,7 @@ class EntityLinker(Pipe):
return losses return losses
def get_loss(self, examples: Iterable[Example], sentence_encodings): def get_loss(self, examples: Iterable[Example], sentence_encodings):
validate_examples(examples, "EntityLinker.get_loss")
entity_encodings = [] entity_encodings = []
for eg in examples: for eg in examples:
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True) kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)

View File

@ -9,6 +9,7 @@ from ..util import ensure_path, to_disk, from_disk
from ..tokens import Doc, Span from ..tokens import Doc, Span
from ..matcher import Matcher, PhraseMatcher from ..matcher import Matcher, PhraseMatcher
from ..scorer import Scorer from ..scorer import Scorer
from ..gold import validate_examples
DEFAULT_ENT_ID_SEP = "||" DEFAULT_ENT_ID_SEP = "||"
@ -312,6 +313,7 @@ class EntityRuler:
return label return label
def score(self, examples, **kwargs): def score(self, examples, **kwargs):
validate_examples(examples, "EntityRuler.score")
return Scorer.score_spans(examples, "ents", **kwargs) return Scorer.score_spans(examples, "ents", **kwargs)
def from_bytes( def from_bytes(

View File

@ -1,5 +1,4 @@
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
from thinc.api import Model from thinc.api import Model
from .pipe import Pipe from .pipe import Pipe
@ -9,6 +8,7 @@ from ..lookups import Lookups, load_lookups
from ..scorer import Scorer from ..scorer import Scorer
from ..tokens import Doc, Token from ..tokens import Doc, Token
from ..vocab import Vocab from ..vocab import Vocab
from ..gold import validate_examples
from .. import util from .. import util
@ -135,10 +135,10 @@ class Lemmatizer(Pipe):
elif self.mode == "rule": elif self.mode == "rule":
self.lemmatize = self.rule_lemmatize self.lemmatize = self.rule_lemmatize
else: else:
try: mode_attr = f"{self.mode}_lemmatize"
self.lemmatize = getattr(self, f"{self.mode}_lemmatize") if not hasattr(self, mode_attr):
except AttributeError:
raise ValueError(Errors.E1003.format(mode=mode)) raise ValueError(Errors.E1003.format(mode=mode))
self.lemmatize = getattr(self, mode_attr)
self.cache = {} self.cache = {}
@property @property
@ -271,6 +271,7 @@ class Lemmatizer(Pipe):
DOCS: https://spacy.io/api/lemmatizer#score DOCS: https://spacy.io/api/lemmatizer#score
""" """
validate_examples(examples, "Lemmatizer.score")
return Scorer.score_token_attr(examples, "lemma", **kwargs) return Scorer.score_token_attr(examples, "lemma", **kwargs)
def to_disk(self, path, *, exclude=tuple()): def to_disk(self, path, *, exclude=tuple()):

View File

@ -6,15 +6,16 @@ from thinc.api import SequenceCategoricalCrossentropy, Model, Config
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
from ..vocab cimport Vocab from ..vocab cimport Vocab
from ..morphology cimport Morphology from ..morphology cimport Morphology
from ..parts_of_speech import IDS as POS_IDS from ..parts_of_speech import IDS as POS_IDS
from ..symbols import POS from ..symbols import POS
from ..language import Language from ..language import Language
from ..errors import Errors from ..errors import Errors
from .pipe import deserialize_config from .pipe import deserialize_config
from .tagger import Tagger from .tagger import Tagger
from .. import util from .. import util
from ..scorer import Scorer from ..scorer import Scorer
from ..gold import validate_examples
default_model_config = """ default_model_config = """
@ -126,7 +127,7 @@ class Morphologizer(Tagger):
self.cfg["labels_pos"][norm_label] = POS_IDS[pos] self.cfg["labels_pos"][norm_label] = POS_IDS[pos]
return 1 return 1
def begin_training(self, get_examples=lambda: [], *, pipeline=None, sgd=None): def begin_training(self, get_examples, *, pipeline=None, sgd=None):
"""Initialize the pipe for training, using data examples if available. """Initialize the pipe for training, using data examples if available.
get_examples (Callable[[], Iterable[Example]]): Optional function that get_examples (Callable[[], Iterable[Example]]): Optional function that
@ -140,6 +141,9 @@ class Morphologizer(Tagger):
DOCS: https://spacy.io/api/morphologizer#begin_training DOCS: https://spacy.io/api/morphologizer#begin_training
""" """
if not hasattr(get_examples, "__call__"):
err = Errors.E930.format(name="Morphologizer", obj=type(get_examples))
raise ValueError(err)
for example in get_examples(): for example in get_examples():
for i, token in enumerate(example.reference): for i, token in enumerate(example.reference):
pos = token.pos_ pos = token.pos_
@ -192,6 +196,7 @@ class Morphologizer(Tagger):
DOCS: https://spacy.io/api/morphologizer#get_loss DOCS: https://spacy.io/api/morphologizer#get_loss
""" """
validate_examples(examples, "Morphologizer.get_loss")
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False) loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
truths = [] truths = []
for eg in examples: for eg in examples:
@ -228,6 +233,7 @@ class Morphologizer(Tagger):
DOCS: https://spacy.io/api/morphologizer#score DOCS: https://spacy.io/api/morphologizer#score
""" """
validate_examples(examples, "Morphologizer.score")
results = {} results = {}
results.update(Scorer.score_token_attr(examples, "pos", **kwargs)) results.update(Scorer.score_token_attr(examples, "pos", **kwargs))
results.update(Scorer.score_token_attr(examples, "morph", **kwargs)) results.update(Scorer.score_token_attr(examples, "morph", **kwargs))

View File

@ -8,6 +8,7 @@ from ..tokens.doc cimport Doc
from .pipe import Pipe from .pipe import Pipe
from .tagger import Tagger from .tagger import Tagger
from ..gold import validate_examples
from ..language import Language from ..language import Language
from ._parser_internals import nonproj from ._parser_internals import nonproj
from ..attrs import POS, ID from ..attrs import POS, ID
@ -80,10 +81,11 @@ class MultitaskObjective(Tagger):
def set_annotations(self, docs, dep_ids): def set_annotations(self, docs, dep_ids):
pass pass
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None): def begin_training(self, get_examples, pipeline=None, sgd=None):
gold_examples = nonproj.preprocess_training_data(get_examples()) if not hasattr(get_examples, "__call__"):
# for raw_text, doc_annot in gold_tuples: err = Errors.E930.format(name="MultitaskObjective", obj=type(get_examples))
for example in gold_examples: raise ValueError(err)
for example in get_examples():
for token in example.y: for token in example.y:
label = self.make_label(token) label = self.make_label(token)
if label is not None and label not in self.labels: if label is not None and label not in self.labels:
@ -175,7 +177,7 @@ class ClozeMultitask(Pipe):
def set_annotations(self, docs, dep_ids): def set_annotations(self, docs, dep_ids):
pass pass
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None): def begin_training(self, get_examples, pipeline=None, sgd=None):
self.model.initialize() self.model.initialize()
X = self.model.ops.alloc((5, self.model.get_ref("tok2vec").get_dim("nO"))) X = self.model.ops.alloc((5, self.model.get_ref("tok2vec").get_dim("nO")))
self.model.output_layer.begin_training(X) self.model.output_layer.begin_training(X)
@ -189,6 +191,7 @@ class ClozeMultitask(Pipe):
return tokvecs, vectors return tokvecs, vectors
def get_loss(self, examples, vectors, prediction): def get_loss(self, examples, vectors, prediction):
validate_examples(examples, "ClozeMultitask.get_loss")
# The simplest way to implement this would be to vstack the # The simplest way to implement this would be to vstack the
# token.vector values, but that's a bit inefficient, especially on GPU. # token.vector values, but that's a bit inefficient, especially on GPU.
# Instead we fetch the index into the vectors table for each of our tokens, # Instead we fetch the index into the vectors table for each of our tokens,
@ -206,18 +209,16 @@ class ClozeMultitask(Pipe):
if losses is not None and self.name not in losses: if losses is not None and self.name not in losses:
losses[self.name] = 0. losses[self.name] = 0.
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
try: validate_examples(examples, "ClozeMultitask.rehearse")
predictions, bp_predictions = self.model.begin_update([eg.predicted for eg in examples]) docs = [eg.predicted for eg in examples]
except AttributeError: predictions, bp_predictions = self.model.begin_update()
types = set([type(eg) for eg in examples])
raise TypeError(Errors.E978.format(name="ClozeMultitask", method="rehearse", types=types)) from None
loss, d_predictions = self.get_loss(examples, self.vocab.vectors.data, predictions) loss, d_predictions = self.get_loss(examples, self.vocab.vectors.data, predictions)
bp_predictions(d_predictions) bp_predictions(d_predictions)
if sgd is not None: if sgd is not None:
self.model.finish_update(sgd) self.model.finish_update(sgd)
if losses is not None: if losses is not None:
losses[self.name] += loss losses[self.name] += loss
return losses
def add_label(self, label): def add_label(self, label):
raise NotImplementedError raise NotImplementedError

View File

@ -7,6 +7,7 @@ from ._parser_internals.ner cimport BiluoPushDown
from ..language import Language from ..language import Language
from ..scorer import Scorer from ..scorer import Scorer
from ..gold import validate_examples
default_model_config = """ default_model_config = """
@ -120,4 +121,5 @@ cdef class EntityRecognizer(Parser):
DOCS: https://spacy.io/api/entityrecognizer#score DOCS: https://spacy.io/api/entityrecognizer#score
""" """
validate_examples(examples, "EntityRecognizer.score")
return Scorer.score_spans(examples, "ents", **kwargs) return Scorer.score_spans(examples, "ents", **kwargs)

View File

@ -1,2 +1,5 @@
cdef class Pipe: cdef class Pipe:
cdef public object vocab
cdef public object model
cdef public str name cdef public str name
cdef public object cfg

View File

@ -1,9 +1,10 @@
# cython: infer_types=True, profile=True # cython: infer_types=True, profile=True
import srsly import srsly
from thinc.api import set_dropout_rate, Model
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
from ..util import create_default_optimizer from ..gold import validate_examples
from ..errors import Errors from ..errors import Errors
from .. import util from .. import util
@ -16,7 +17,6 @@ cdef class Pipe:
DOCS: https://spacy.io/api/pipe DOCS: https://spacy.io/api/pipe
""" """
def __init__(self, vocab, model, name, **cfg): def __init__(self, vocab, model, name, **cfg):
"""Initialize a pipeline component. """Initialize a pipeline component.
@ -27,7 +27,10 @@ cdef class Pipe:
DOCS: https://spacy.io/api/pipe#init DOCS: https://spacy.io/api/pipe#init
""" """
raise NotImplementedError self.vocab = vocab
self.model = model
self.name = name
self.cfg = dict(cfg)
def __call__(self, Doc doc): def __call__(self, Doc doc):
"""Apply the pipe to one document. The document is modified in place, """Apply the pipe to one document. The document is modified in place,
@ -68,7 +71,7 @@ cdef class Pipe:
DOCS: https://spacy.io/api/pipe#predict DOCS: https://spacy.io/api/pipe#predict
""" """
raise NotImplementedError raise NotImplementedError(Errors.E931.format(method="predict", name=self.name))
def set_annotations(self, docs, scores): def set_annotations(self, docs, scores):
"""Modify a batch of documents, using pre-computed scores. """Modify a batch of documents, using pre-computed scores.
@ -78,7 +81,43 @@ cdef class Pipe:
DOCS: https://spacy.io/api/pipe#set_annotations DOCS: https://spacy.io/api/pipe#set_annotations
""" """
raise NotImplementedError raise NotImplementedError(Errors.E931.format(method="set_annotations", name=self.name))
def update(self, examples, *, drop=0.0, set_annotations=False, sgd=None, losses=None):
"""Learn from a batch of documents and gold-standard information,
updating the pipe's model. Delegates to predict and get_loss.
examples (Iterable[Example]): A batch of Example objects.
drop (float): The dropout rate.
set_annotations (bool): Whether or not to update the Example objects
with the predictions.
sgd (thinc.api.Optimizer): The optimizer.
losses (Dict[str, float]): Optional record of the loss during training.
Updated using the component name as the key.
RETURNS (Dict[str, float]): The updated losses dictionary.
DOCS: https://spacy.io/api/pipe#update
"""
if losses is None:
losses = {}
if not hasattr(self, "model") or self.model in (None, True, False):
return losses
losses.setdefault(self.name, 0.0)
validate_examples(examples, "Pipe.update")
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
# Handle cases where there are no tokens in any docs.
return
set_dropout_rate(self.model, drop)
scores, bp_scores = self.model.begin_update([eg.predicted for eg in examples])
loss, d_scores = self.get_loss(examples, scores)
bp_scores(d_scores)
if sgd not in (None, False):
self.model.finish_update(sgd)
losses[self.name] += loss
if set_annotations:
docs = [eg.predicted for eg in examples]
self.set_annotations(docs, scores=scores)
return losses
def rehearse(self, examples, *, sgd=None, losses=None, **config): def rehearse(self, examples, *, sgd=None, losses=None, **config):
"""Perform a "rehearsal" update from a batch of data. Rehearsal updates """Perform a "rehearsal" update from a batch of data. Rehearsal updates
@ -107,7 +146,7 @@ cdef class Pipe:
DOCS: https://spacy.io/api/pipe#get_loss DOCS: https://spacy.io/api/pipe#get_loss
""" """
raise NotImplementedError raise NotImplementedError(Errors.E931.format(method="get_loss", name=self.name))
def add_label(self, label): def add_label(self, label):
"""Add an output label, to be predicted by the model. It's possible to """Add an output label, to be predicted by the model. It's possible to
@ -119,7 +158,7 @@ cdef class Pipe:
DOCS: https://spacy.io/api/pipe#add_label DOCS: https://spacy.io/api/pipe#add_label
""" """
raise NotImplementedError raise NotImplementedError(Errors.E931.format(method="add_label", name=self.name))
def create_optimizer(self): def create_optimizer(self):
"""Create an optimizer for the pipeline component. """Create an optimizer for the pipeline component.
@ -128,9 +167,9 @@ cdef class Pipe:
DOCS: https://spacy.io/api/pipe#create_optimizer DOCS: https://spacy.io/api/pipe#create_optimizer
""" """
return create_default_optimizer() return util.create_default_optimizer()
def begin_training(self, get_examples=lambda: [], *, pipeline=None, sgd=None): def begin_training(self, get_examples, *, pipeline=None, sgd=None):
"""Initialize the pipe for training, using data examples if available. """Initialize the pipe for training, using data examples if available.
get_examples (Callable[[], Iterable[Example]]): Optional function that get_examples (Callable[[], Iterable[Example]]): Optional function that

View File

@ -7,6 +7,7 @@ from ..tokens.doc cimport Doc
from .pipe import Pipe from .pipe import Pipe
from ..language import Language from ..language import Language
from ..scorer import Scorer from ..scorer import Scorer
from ..gold import validate_examples
from .. import util from .. import util
@ -58,7 +59,7 @@ class Sentencizer(Pipe):
else: else:
self.punct_chars = set(self.default_punct_chars) self.punct_chars = set(self.default_punct_chars)
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None): def begin_training(self, get_examples, pipeline=None, sgd=None):
pass pass
def __call__(self, doc): def __call__(self, doc):
@ -158,6 +159,7 @@ class Sentencizer(Pipe):
DOCS: https://spacy.io/api/sentencizer#score DOCS: https://spacy.io/api/sentencizer#score
""" """
validate_examples(examples, "Sentencizer.score")
results = Scorer.score_spans(examples, "sents", **kwargs) results = Scorer.score_spans(examples, "sents", **kwargs)
del results["sents_per_type"] del results["sents_per_type"]
return results return results

View File

@ -9,6 +9,7 @@ from .tagger import Tagger
from ..language import Language from ..language import Language
from ..errors import Errors from ..errors import Errors
from ..scorer import Scorer from ..scorer import Scorer
from ..gold import validate_examples
from .. import util from .. import util
@ -102,6 +103,7 @@ class SentenceRecognizer(Tagger):
DOCS: https://spacy.io/api/sentencerecognizer#get_loss DOCS: https://spacy.io/api/sentencerecognizer#get_loss
""" """
validate_examples(examples, "SentenceRecognizer.get_loss")
labels = self.labels labels = self.labels
loss_func = SequenceCategoricalCrossentropy(names=labels, normalize=False) loss_func = SequenceCategoricalCrossentropy(names=labels, normalize=False)
truths = [] truths = []
@ -121,7 +123,7 @@ class SentenceRecognizer(Tagger):
raise ValueError("nan value when computing loss") raise ValueError("nan value when computing loss")
return float(loss), d_scores return float(loss), d_scores
def begin_training(self, get_examples=lambda: [], *, pipeline=None, sgd=None): def begin_training(self, get_examples, *, pipeline=None, sgd=None):
"""Initialize the pipe for training, using data examples if available. """Initialize the pipe for training, using data examples if available.
get_examples (Callable[[], Iterable[Example]]): Optional function that get_examples (Callable[[], Iterable[Example]]): Optional function that
@ -151,6 +153,7 @@ class SentenceRecognizer(Tagger):
RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans. RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans.
DOCS: https://spacy.io/api/sentencerecognizer#score DOCS: https://spacy.io/api/sentencerecognizer#score
""" """
validate_examples(examples, "SentenceRecognizer.score")
results = Scorer.score_spans(examples, "sents", **kwargs) results = Scorer.score_spans(examples, "sents", **kwargs)
del results["sents_per_type"] del results["sents_per_type"]
return results return results

View File

@ -1,4 +1,4 @@
from typing import List, Iterable, Optional, Dict, Tuple, Callable from typing import List, Iterable, Optional, Dict, Tuple, Callable, Set
from thinc.types import Floats2d from thinc.types import Floats2d
from thinc.api import SequenceCategoricalCrossentropy, set_dropout_rate, Model from thinc.api import SequenceCategoricalCrossentropy, set_dropout_rate, Model
from thinc.api import Optimizer, Config from thinc.api import Optimizer, Config
@ -6,6 +6,7 @@ from thinc.util import to_numpy
from ..errors import Errors from ..errors import Errors
from ..gold import Example, spans_from_biluo_tags, iob_to_biluo, biluo_to_iob from ..gold import Example, spans_from_biluo_tags, iob_to_biluo, biluo_to_iob
from ..gold import validate_examples
from ..tokens import Doc from ..tokens import Doc
from ..language import Language from ..language import Language
from ..vocab import Vocab from ..vocab import Vocab
@ -127,6 +128,7 @@ class SimpleNER(Pipe):
if losses is None: if losses is None:
losses = {} losses = {}
losses.setdefault("ner", 0.0) losses.setdefault("ner", 0.0)
validate_examples(examples, "SimpleNER.update")
if not any(_has_ner(eg) for eg in examples): if not any(_has_ner(eg) for eg in examples):
return losses return losses
docs = [eg.predicted for eg in examples] docs = [eg.predicted for eg in examples]
@ -142,6 +144,7 @@ class SimpleNER(Pipe):
return losses return losses
def get_loss(self, examples: List[Example], scores) -> Tuple[List[Floats2d], float]: def get_loss(self, examples: List[Example], scores) -> Tuple[List[Floats2d], float]:
validate_examples(examples, "SimpleNER.get_loss")
truths = [] truths = []
for eg in examples: for eg in examples:
tags = eg.get_aligned_ner() tags = eg.get_aligned_ner()
@ -161,14 +164,17 @@ class SimpleNER(Pipe):
def begin_training( def begin_training(
self, self,
get_examples: Callable, get_examples: Callable[[], Iterable[Example]],
pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None, pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None,
sgd: Optional[Optimizer] = None, sgd: Optional[Optimizer] = None,
): ):
all_labels = set()
if not hasattr(get_examples, "__call__"): if not hasattr(get_examples, "__call__"):
gold_tuples = get_examples err = Errors.E930.format(name="SimpleNER", obj=type(get_examples))
get_examples = lambda: gold_tuples raise ValueError(err)
for label in _get_labels(get_examples()): for example in get_examples():
all_labels.update(_get_labels(example))
for label in sorted(all_labels):
self.add_label(label) self.add_label(label)
labels = self.labels labels = self.labels
n_actions = self.model.attrs["get_num_actions"](len(labels)) n_actions = self.model.attrs["get_num_actions"](len(labels))
@ -185,6 +191,7 @@ class SimpleNER(Pipe):
pass pass
def score(self, examples, **kwargs): def score(self, examples, **kwargs):
validate_examples(examples, "SimpleNER.score")
return Scorer.score_spans(examples, "ents", **kwargs) return Scorer.score_spans(examples, "ents", **kwargs)
@ -196,10 +203,9 @@ def _has_ner(example: Example) -> bool:
return False return False
def _get_labels(examples: List[Example]) -> List[str]: def _get_labels(example: Example) -> Set[str]:
labels = set() labels = set()
for eg in examples: for ner_tag in example.get_aligned("ENT_TYPE", as_string=True):
for ner_tag in eg.get_aligned("ENT_TYPE", as_string=True):
if ner_tag != "O" and ner_tag != "-": if ner_tag != "O" and ner_tag != "-":
labels.add(ner_tag) labels.add(ner_tag)
return list(sorted(labels)) return labels

View File

@ -16,6 +16,7 @@ from ..attrs import POS, ID
from ..parts_of_speech import X from ..parts_of_speech import X
from ..errors import Errors, TempErrors, Warnings from ..errors import Errors, TempErrors, Warnings
from ..scorer import Scorer from ..scorer import Scorer
from ..gold import validate_examples
from .. import util from .. import util
@ -187,19 +188,15 @@ class Tagger(Pipe):
if losses is None: if losses is None:
losses = {} losses = {}
losses.setdefault(self.name, 0.0) losses.setdefault(self.name, 0.0)
try: validate_examples(examples, "Tagger.update")
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples): if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
return return
except AttributeError:
types = set([type(eg) for eg in examples])
raise TypeError(Errors.E978.format(name="Tagger", method="update", types=types)) from None
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
tag_scores, bp_tag_scores = self.model.begin_update( tag_scores, bp_tag_scores = self.model.begin_update([eg.predicted for eg in examples])
[eg.predicted for eg in examples])
for sc in tag_scores: for sc in tag_scores:
if self.model.ops.xp.isnan(sc.sum()): if self.model.ops.xp.isnan(sc.sum()):
raise ValueError("nan value in scores") raise ValueError(Errors.E940)
loss, d_tag_scores = self.get_loss(examples, tag_scores) loss, d_tag_scores = self.get_loss(examples, tag_scores)
bp_tag_scores(d_tag_scores) bp_tag_scores(d_tag_scores)
if sgd not in (None, False): if sgd not in (None, False):
@ -226,11 +223,8 @@ class Tagger(Pipe):
DOCS: https://spacy.io/api/tagger#rehearse DOCS: https://spacy.io/api/tagger#rehearse
""" """
try: validate_examples(examples, "Tagger.rehearse")
docs = [eg.predicted for eg in examples] docs = [eg.predicted for eg in examples]
except AttributeError:
types = set([type(eg) for eg in examples])
raise TypeError(Errors.E978.format(name="Tagger", method="rehearse", types=types)) from None
if self._rehearsal_model is None: if self._rehearsal_model is None:
return return
if not any(len(doc) for doc in docs): if not any(len(doc) for doc in docs):
@ -256,6 +250,7 @@ class Tagger(Pipe):
DOCS: https://spacy.io/api/tagger#get_loss DOCS: https://spacy.io/api/tagger#get_loss
""" """
validate_examples(examples, "Tagger.get_loss")
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False) loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
truths = [eg.get_aligned("TAG", as_string=True) for eg in examples] truths = [eg.get_aligned("TAG", as_string=True) for eg in examples]
d_scores, loss = loss_func(scores, truths) d_scores, loss = loss_func(scores, truths)
@ -263,7 +258,7 @@ class Tagger(Pipe):
raise ValueError("nan value when computing loss") raise ValueError("nan value when computing loss")
return float(loss), d_scores return float(loss), d_scores
def begin_training(self, get_examples=lambda: [], *, pipeline=None, sgd=None): def begin_training(self, get_examples, *, pipeline=None, sgd=None):
"""Initialize the pipe for training, using data examples if available. """Initialize the pipe for training, using data examples if available.
get_examples (Callable[[], Iterable[Example]]): Optional function that get_examples (Callable[[], Iterable[Example]]): Optional function that
@ -277,13 +272,12 @@ class Tagger(Pipe):
DOCS: https://spacy.io/api/tagger#begin_training DOCS: https://spacy.io/api/tagger#begin_training
""" """
if not hasattr(get_examples, "__call__"):
err = Errors.E930.format(name="Tagger", obj=type(get_examples))
raise ValueError(err)
tags = set() tags = set()
for example in get_examples(): for example in get_examples():
try: for token in example.y:
y = example.y
except AttributeError:
raise TypeError(Errors.E978.format(name="Tagger", method="begin_training", types=type(example))) from None
for token in y:
tags.add(token.tag_) tags.add(token.tag_)
for tag in sorted(tags): for tag in sorted(tags):
self.add_label(tag) self.add_label(tag)
@ -318,6 +312,7 @@ class Tagger(Pipe):
DOCS: https://spacy.io/api/tagger#score DOCS: https://spacy.io/api/tagger#score
""" """
validate_examples(examples, "Tagger.score")
return Scorer.score_token_attr(examples, "tag", **kwargs) return Scorer.score_token_attr(examples, "tag", **kwargs)
def to_bytes(self, *, exclude=tuple()): def to_bytes(self, *, exclude=tuple()):

View File

@ -5,7 +5,7 @@ import numpy
from .pipe import Pipe from .pipe import Pipe
from ..language import Language from ..language import Language
from ..gold import Example from ..gold import Example, validate_examples
from ..errors import Errors from ..errors import Errors
from ..scorer import Scorer from ..scorer import Scorer
from .. import util from .. import util
@ -209,15 +209,10 @@ class TextCategorizer(Pipe):
if losses is None: if losses is None:
losses = {} losses = {}
losses.setdefault(self.name, 0.0) losses.setdefault(self.name, 0.0)
try: validate_examples(examples, "TextCategorizer.update")
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples): if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
return losses return losses
except AttributeError:
types = set([type(eg) for eg in examples])
raise TypeError(
Errors.E978.format(name="TextCategorizer", method="update", types=types)
) from None
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
scores, bp_scores = self.model.begin_update([eg.predicted for eg in examples]) scores, bp_scores = self.model.begin_update([eg.predicted for eg in examples])
loss, d_scores = self.get_loss(examples, scores) loss, d_scores = self.get_loss(examples, scores)
@ -252,19 +247,12 @@ class TextCategorizer(Pipe):
DOCS: https://spacy.io/api/textcategorizer#rehearse DOCS: https://spacy.io/api/textcategorizer#rehearse
""" """
if losses is not None: if losses is not None:
losses.setdefault(self.name, 0.0) losses.setdefault(self.name, 0.0)
if self._rehearsal_model is None: if self._rehearsal_model is None:
return losses return losses
try: validate_examples(examples, "TextCategorizer.rehearse")
docs = [eg.predicted for eg in examples] docs = [eg.predicted for eg in examples]
except AttributeError:
types = set([type(eg) for eg in examples])
err = Errors.E978.format(
name="TextCategorizer", method="rehearse", types=types
)
raise TypeError(err) from None
if not any(len(doc) for doc in docs): if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
return losses return losses
@ -303,6 +291,7 @@ class TextCategorizer(Pipe):
DOCS: https://spacy.io/api/textcategorizer#get_loss DOCS: https://spacy.io/api/textcategorizer#get_loss
""" """
validate_examples(examples, "TextCategorizer.get_loss")
truths, not_missing = self._examples_to_truth(examples) truths, not_missing = self._examples_to_truth(examples)
not_missing = self.model.ops.asarray(not_missing) not_missing = self.model.ops.asarray(not_missing)
d_scores = (scores - truths) / scores.shape[0] d_scores = (scores - truths) / scores.shape[0]
@ -338,7 +327,7 @@ class TextCategorizer(Pipe):
def begin_training( def begin_training(
self, self,
get_examples: Callable[[], Iterable[Example]] = lambda: [], get_examples: Callable[[], Iterable[Example]],
*, *,
pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None, pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None,
sgd: Optional[Optimizer] = None, sgd: Optional[Optimizer] = None,
@ -356,21 +345,20 @@ class TextCategorizer(Pipe):
DOCS: https://spacy.io/api/textcategorizer#begin_training DOCS: https://spacy.io/api/textcategorizer#begin_training
""" """
# TODO: begin_training is not guaranteed to see all data / labels ? if not hasattr(get_examples, "__call__"):
examples = list(get_examples()) err = Errors.E930.format(name="TextCategorizer", obj=type(get_examples))
for example in examples: raise ValueError(err)
try: subbatch = [] # Select a subbatch of examples to initialize the model
y = example.y for example in get_examples():
except AttributeError: if len(subbatch) < 2:
err = Errors.E978.format( subbatch.append(example)
name="TextCategorizer", method="update", types=type(example) for cat in example.y.cats:
)
raise TypeError(err) from None
for cat in y.cats:
self.add_label(cat) self.add_label(cat)
self.require_labels() self.require_labels()
docs = [eg.reference for eg in subbatch]
if not docs: # need at least one doc
docs = [Doc(self.vocab, words=["hello"])] docs = [Doc(self.vocab, words=["hello"])]
truths, _ = self._examples_to_truth(examples) truths, _ = self._examples_to_truth(subbatch)
self.set_output(len(self.labels)) self.set_output(len(self.labels))
self.model.initialize(X=docs, Y=truths) self.model.initialize(X=docs, Y=truths)
if sgd is None: if sgd is None:
@ -392,6 +380,7 @@ class TextCategorizer(Pipe):
DOCS: https://spacy.io/api/textcategorizer#score DOCS: https://spacy.io/api/textcategorizer#score
""" """
validate_examples(examples, "TextCategorizer.score")
return Scorer.score_cats( return Scorer.score_cats(
examples, examples,
"cats", "cats",

View File

@ -2,7 +2,7 @@ from typing import Iterator, Sequence, Iterable, Optional, Dict, Callable, List,
from thinc.api import Model, set_dropout_rate, Optimizer, Config from thinc.api import Model, set_dropout_rate, Optimizer, Config
from .pipe import Pipe from .pipe import Pipe
from ..gold import Example from ..gold import Example, validate_examples
from ..tokens import Doc from ..tokens import Doc
from ..vocab import Vocab from ..vocab import Vocab
from ..language import Language from ..language import Language
@ -166,9 +166,8 @@ class Tok2Vec(Pipe):
""" """
if losses is None: if losses is None:
losses = {} losses = {}
validate_examples(examples, "Tok2Vec.update")
docs = [eg.predicted for eg in examples] docs = [eg.predicted for eg in examples]
if isinstance(docs, Doc):
docs = [docs]
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
tokvecs, bp_tokvecs = self.model.begin_update(docs) tokvecs, bp_tokvecs = self.model.begin_update(docs)
d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs] d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
@ -204,7 +203,7 @@ class Tok2Vec(Pipe):
def begin_training( def begin_training(
self, self,
get_examples: Callable[[], Iterable[Example]] = lambda: [], get_examples: Callable[[], Iterable[Example]],
*, *,
pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None, pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None,
sgd: Optional[Optimizer] = None, sgd: Optional[Optimizer] = None,

View File

@ -8,11 +8,8 @@ from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC
cdef class Parser(Pipe): cdef class Parser(Pipe):
cdef readonly Vocab vocab
cdef public object model
cdef public object _rehearsal_model cdef public object _rehearsal_model
cdef readonly TransitionSystem moves cdef readonly TransitionSystem moves
cdef readonly object cfg
cdef public object _multitasks cdef public object _multitasks
cdef void _parseC(self, StateC** states, cdef void _parseC(self, StateC** states,

View File

@ -8,22 +8,21 @@ from libc.string cimport memset
from libc.stdlib cimport calloc, free from libc.stdlib cimport calloc, free
import srsly import srsly
from thinc.api import set_dropout_rate
import numpy.random
import numpy
import warnings
from ._parser_internals.stateclass cimport StateClass from ._parser_internals.stateclass cimport StateClass
from ..ml.parser_model cimport alloc_activations, free_activations from ..ml.parser_model cimport alloc_activations, free_activations
from ..ml.parser_model cimport predict_states, arg_max_if_valid from ..ml.parser_model cimport predict_states, arg_max_if_valid
from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
from ..ml.parser_model cimport get_c_weights, get_c_sizes from ..ml.parser_model cimport get_c_weights, get_c_sizes
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
from ..gold import validate_examples
from ..errors import Errors, Warnings from ..errors import Errors, Warnings
from .. import util from .. import util
from ..util import create_default_optimizer
from thinc.api import set_dropout_rate
import numpy.random
import numpy
import warnings
cdef class Parser(Pipe): cdef class Parser(Pipe):
@ -266,6 +265,7 @@ cdef class Parser(Pipe):
if losses is None: if losses is None:
losses = {} losses = {}
losses.setdefault(self.name, 0.) losses.setdefault(self.name, 0.)
validate_examples(examples, "Parser.update")
for multitask in self._multitasks: for multitask in self._multitasks:
multitask.update(examples, drop=drop, sgd=sgd) multitask.update(examples, drop=drop, sgd=sgd)
n_examples = len([eg for eg in examples if self.moves.has_gold(eg)]) n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
@ -329,7 +329,7 @@ cdef class Parser(Pipe):
if self._rehearsal_model is None: if self._rehearsal_model is None:
return None return None
losses.setdefault(self.name, 0.) losses.setdefault(self.name, 0.)
validate_examples(examples, "Parser.rehearse")
docs = [eg.predicted for eg in examples] docs = [eg.predicted for eg in examples]
states = self.moves.init_batch(docs) states = self.moves.init_batch(docs)
# This is pretty dirty, but the NER can resize itself in init_batch, # This is pretty dirty, but the NER can resize itself in init_batch,
@ -398,21 +398,18 @@ cdef class Parser(Pipe):
losses[self.name] += (d_scores**2).sum() losses[self.name] += (d_scores**2).sum()
return d_scores return d_scores
def create_optimizer(self):
return create_default_optimizer()
def set_output(self, nO): def set_output(self, nO):
self.model.attrs["resize_output"](self.model, nO) self.model.attrs["resize_output"](self.model, nO)
def begin_training(self, get_examples, pipeline=None, sgd=None, **kwargs): def begin_training(self, get_examples, pipeline=None, sgd=None, **kwargs):
if not hasattr(get_examples, "__call__"):
err = Errors.E930.format(name="DependencyParser/EntityRecognizer", obj=type(get_examples))
raise ValueError(err)
self.cfg.update(kwargs) self.cfg.update(kwargs)
lexeme_norms = self.vocab.lookups.get_table("lexeme_norm", {}) lexeme_norms = self.vocab.lookups.get_table("lexeme_norm", {})
if len(lexeme_norms) == 0 and self.vocab.lang in util.LEXEME_NORM_LANGS: if len(lexeme_norms) == 0 and self.vocab.lang in util.LEXEME_NORM_LANGS:
langs = ", ".join(util.LEXEME_NORM_LANGS) langs = ", ".join(util.LEXEME_NORM_LANGS)
warnings.warn(Warnings.W033.format(model="parser or NER", langs=langs)) warnings.warn(Warnings.W033.format(model="parser or NER", langs=langs))
if not hasattr(get_examples, '__call__'):
gold_tuples = get_examples
get_examples = lambda: gold_tuples
actions = self.moves.get_actions( actions = self.moves.get_actions(
examples=get_examples(), examples=get_examples(),
min_freq=self.cfg['min_action_freq'], min_freq=self.cfg['min_action_freq'],

View File

@ -18,7 +18,7 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
cfg = {"model": DEFAULT_NER_MODEL} cfg = {"model": DEFAULT_NER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.make_from_config(cfg, validate=True)["model"]
ner = EntityRecognizer(en_vocab, model, **config) ner = EntityRecognizer(en_vocab, model, **config)
ner.begin_training([]) ner.begin_training(lambda: [])
ner(doc) ner(doc)
assert len(list(doc.ents)) == 0 assert len(list(doc.ents)) == 0
assert [w.ent_iob_ for w in doc] == (["O"] * len(doc)) assert [w.ent_iob_ for w in doc] == (["O"] * len(doc))
@ -41,7 +41,7 @@ def test_ents_reset(en_vocab):
cfg = {"model": DEFAULT_NER_MODEL} cfg = {"model": DEFAULT_NER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.make_from_config(cfg, validate=True)["model"]
ner = EntityRecognizer(en_vocab, model, **config) ner = EntityRecognizer(en_vocab, model, **config)
ner.begin_training([]) ner.begin_training(lambda: [])
ner(doc) ner(doc)
assert [t.ent_iob_ for t in doc] == (["O"] * len(doc)) assert [t.ent_iob_ for t in doc] == (["O"] * len(doc))
doc.ents = list(doc.ents) doc.ents = list(doc.ents)

View File

@ -35,7 +35,7 @@ def test_init_parser(parser):
def _train_parser(parser): def _train_parser(parser):
fix_random_seed(1) fix_random_seed(1)
parser.add_label("left") parser.add_label("left")
parser.begin_training([], **parser.cfg) parser.begin_training(lambda: [], **parser.cfg)
sgd = Adam(0.001) sgd = Adam(0.001)
for i in range(5): for i in range(5):
@ -75,7 +75,7 @@ def test_add_label_deserializes_correctly():
ner1.add_label("C") ner1.add_label("C")
ner1.add_label("B") ner1.add_label("B")
ner1.add_label("A") ner1.add_label("A")
ner1.begin_training([]) ner1.begin_training(lambda: [])
ner2 = EntityRecognizer(Vocab(), model, **config) ner2 = EntityRecognizer(Vocab(), model, **config)
# the second model needs to be resized before we can call from_bytes # the second model needs to be resized before we can call from_bytes

View File

@ -28,7 +28,7 @@ def parser(vocab):
parser.cfg["hidden_width"] = 32 parser.cfg["hidden_width"] = 32
# parser.add_label('right') # parser.add_label('right')
parser.add_label("left") parser.add_label("left")
parser.begin_training([], **parser.cfg) parser.begin_training(lambda: [], **parser.cfg)
sgd = Adam(0.001) sgd = Adam(0.001)
for i in range(10): for i in range(10):

View File

@ -136,7 +136,7 @@ def test_kb_undefined(nlp):
"""Test that the EL can't train without defining a KB""" """Test that the EL can't train without defining a KB"""
entity_linker = nlp.add_pipe("entity_linker", config={}) entity_linker = nlp.add_pipe("entity_linker", config={})
with pytest.raises(ValueError): with pytest.raises(ValueError):
entity_linker.begin_training() entity_linker.begin_training(lambda: [])
def test_kb_empty(nlp): def test_kb_empty(nlp):
@ -145,7 +145,7 @@ def test_kb_empty(nlp):
entity_linker = nlp.add_pipe("entity_linker", config=config) entity_linker = nlp.add_pipe("entity_linker", config=config)
assert len(entity_linker.kb) == 0 assert len(entity_linker.kb) == 0
with pytest.raises(ValueError): with pytest.raises(ValueError):
entity_linker.begin_training() entity_linker.begin_training(lambda: [])
def test_candidate_generation(nlp): def test_candidate_generation(nlp):
@ -249,7 +249,7 @@ def test_preserving_links_asdoc(nlp):
ruler.add_patterns(patterns) ruler.add_patterns(patterns)
el_config = {"kb": {"@assets": "myLocationsKB.v1"}, "incl_prior": False} el_config = {"kb": {"@assets": "myLocationsKB.v1"}, "incl_prior": False}
el_pipe = nlp.add_pipe("entity_linker", config=el_config, last=True) el_pipe = nlp.add_pipe("entity_linker", config=el_config, last=True)
el_pipe.begin_training() el_pipe.begin_training(lambda: [])
el_pipe.incl_context = False el_pipe.incl_context = False
el_pipe.incl_prior = True el_pipe.incl_prior = True

View File

@ -54,7 +54,7 @@ def test_textcat_learns_multilabel():
textcat = TextCategorizer(nlp.vocab, width=8) textcat = TextCategorizer(nlp.vocab, width=8)
for letter in letters: for letter in letters:
textcat.add_label(letter) textcat.add_label(letter)
optimizer = textcat.begin_training() optimizer = textcat.begin_training(lambda: [])
for i in range(30): for i in range(30):
losses = {} losses = {}
examples = [Example.from_dict(doc, {"cats": cats}) for doc, cat in docs] examples = [Example.from_dict(doc, {"cats": cats}) for doc, cat in docs]

View File

@ -20,7 +20,7 @@ def test_issue2564():
nlp = Language() nlp = Language()
tagger = nlp.add_pipe("tagger") tagger = nlp.add_pipe("tagger")
tagger.add_label("A") tagger.add_label("A")
tagger.begin_training() tagger.begin_training(lambda: [])
doc = nlp("hello world") doc = nlp("hello world")
assert doc.is_tagged assert doc.is_tagged
docs = nlp.pipe(["hello", "world"]) docs = nlp.pipe(["hello", "world"])

View File

@ -303,7 +303,7 @@ def test_issue4313():
config = {} config = {}
ner = nlp.create_pipe("ner", config=config) ner = nlp.create_pipe("ner", config=config)
ner.add_label("SOME_LABEL") ner.add_label("SOME_LABEL")
ner.begin_training([]) ner.begin_training(lambda: [])
# add a new label to the doc # add a new label to the doc
doc = nlp("What do you think about Apple ?") doc = nlp("What do you think about Apple ?")
assert len(ner.labels) == 1 assert len(ner.labels) == 1

View File

@ -62,7 +62,7 @@ def tagger():
# need to add model for two reasons: # need to add model for two reasons:
# 1. no model leads to error in serialization, # 1. no model leads to error in serialization,
# 2. the affected line is the one for model serialization # 2. the affected line is the one for model serialization
tagger.begin_training(pipeline=nlp.pipeline) tagger.begin_training(lambda: [], pipeline=nlp.pipeline)
return tagger return tagger
@ -81,7 +81,7 @@ def entity_linker():
# need to add model for two reasons: # need to add model for two reasons:
# 1. no model leads to error in serialization, # 1. no model leads to error in serialization,
# 2. the affected line is the one for model serialization # 2. the affected line is the one for model serialization
entity_linker.begin_training(pipeline=nlp.pipeline) entity_linker.begin_training(lambda: [], pipeline=nlp.pipeline)
return entity_linker return entity_linker

View File

@ -24,6 +24,7 @@ from .util import registry
from .attrs import intify_attrs from .attrs import intify_attrs
from .symbols import ORTH from .symbols import ORTH
from .scorer import Scorer from .scorer import Scorer
from .gold import validate_examples
cdef class Tokenizer: cdef class Tokenizer:
@ -712,6 +713,7 @@ cdef class Tokenizer:
return tokens return tokens
def score(self, examples, **kwargs): def score(self, examples, **kwargs):
validate_examples(examples, "Tokenizer.score")
return Scorer.score_tokenization(examples) return Scorer.score_tokenization(examples)
def to_disk(self, path, **kwargs): def to_disk(self, path, **kwargs):

View File

@ -45,18 +45,12 @@ Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#create_pipe). [`nlp.add_pipe`](/api/language#create_pipe).
<Infobox variant="danger">
This method needs to be overwritten with your own custom `__init__` method.
</Infobox>
| Name | Type | Description | | Name | Type | Description |
| ------- | ------------------------------------------ | ------------------------------------------------------------------------------------------- | | ------- | ------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------- |
| `vocab` | `Vocab` | The shared vocabulary. | | `vocab` | `Vocab` | The shared vocabulary. |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. | | `model` | [`Model`](https://thinc.ai/docs/api-model) | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
| `name` | str | String name of the component instance. Used to add entries to the `losses` during training. | | `name` | str | String name of the component instance. Used to add entries to the `losses` during training. |
| `**cfg` | | Additional config parameters and settings. | | `**cfg` | | Additional config parameters and settings. Will be available as the dictionary `Pipe.cfg` and is serialized with the component. |
## Pipe.\_\_call\_\_ {#call tag="method"} ## Pipe.\_\_call\_\_ {#call tag="method"}
@ -182,12 +176,6 @@ method.
Learn from a batch of [`Example`](/api/example) objects containing the Learn from a batch of [`Example`](/api/example) objects containing the
predictions and gold-standard annotations, and update the component's model. predictions and gold-standard annotations, and update the component's model.
<Infobox variant="danger">
This method needs to be overwritten with your own custom `update` method.
</Infobox>
> #### Example > #### Example
> >
> ```python > ```python
@ -384,6 +372,15 @@ Load the pipe from a bytestring. Modifies the object in place and returns it.
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. | | `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `Pipe` | The pipe. | | **RETURNS** | `Pipe` | The pipe. |
## Attributes {#attributes}
| Name | Type | Description |
| ------- | ------------------------------------------ | ----------------------------------------------------------------------------------------------------- |
| `vocab` | [`Vocab`](/api/vocab) | The shared vocabulary that's passed in on initialization. |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The model powering the component. |
| `name` | str | The name of the component instance in the pipeline. Can be used in the losses. |
| `cfg` | dict | Keyword arguments passed to [`Pipe.__init__`](/api/pipe#init). Will be serialized with the component. |
## Serialization fields {#serialization-fields} ## Serialization fields {#serialization-fields}
During serialization, spaCy will export several data fields used to restore During serialization, spaCy will export several data fields used to restore

View File

@ -5,7 +5,6 @@ menu:
- ['Processing Text', 'processing'] - ['Processing Text', 'processing']
- ['How Pipelines Work', 'pipelines'] - ['How Pipelines Work', 'pipelines']
- ['Custom Components', 'custom-components'] - ['Custom Components', 'custom-components']
# - ['Trainable Components', 'trainable-components']
- ['Extension Attributes', 'custom-components-attributes'] - ['Extension Attributes', 'custom-components-attributes']
- ['Plugins & Wrappers', 'plugins'] - ['Plugins & Wrappers', 'plugins']
--- ---
@ -885,15 +884,117 @@ available, falls back to looking up the regular factory name.
</Infobox> </Infobox>
<!-- TODO: ### Trainable components {#trainable-components new="3"}
## Trainable components {#trainable-components new="3"}
spaCy's [`Pipe`](/api/pipe) class helps you implement your own trainable spaCy's [`Pipe`](/api/pipe) class helps you implement your own trainable
components that have their own model instance, make predictions over `Doc` components that have their own model instance, make predictions over `Doc`
objects and can be updated using [`spacy train`](/api/cli#train). This lets you objects and can be updated using [`spacy train`](/api/cli#train). This lets you
plug fully custom machine learning components into your pipeline. plug fully custom machine learning components into your pipeline. You'll need
the following:
---> 1. **Model:** A Thinc [`Model`](https://thinc.ai/docs/api-model) instance. This
can be a model using [layers](https://thinc.ai/docs/api-layers) implemented
in Thinc, or a [wrapped model](https://thinc.ai/docs/usage-frameworks)
implemented in PyTorch, TensorFlow, MXNet or a fully custom solution. The
model must take a list of [`Doc`](/api/doc) objects as input and can have any
type of output.
2. **Pipe subclass:** A subclass of [`Pipe`](/api/pipe) that implements at least
two methods: [`Pipe.predict`](/api/pipe#predict) and
[`Pipe.set_annotations`](/api/pipe#set_annotations).
3. **Component factory:** A component factory registered with
[`@Language.factory`](/api/language#factory) that takes the `nlp` object and
component `name` and optional settings provided by the config and returns an
instance of your trainable component.
> #### Example
>
> ```python
> from spacy.pipeline import Pipe
> from spacy.language import Language
>
> class TrainableComponent(Pipe):
> def predict(self, docs):
> ...
>
> def set_annotations(self, docs, scores):
> ...
>
> @Language.factory("my_trainable_component")
> def make_component(nlp, name, model):
> return TrainableComponent(nlp.vocab, model, name=name)
> ```
| Name | Description |
| ---------------------------------------------- | ------------------------------------------------------------------------------------------------------------------- |
| [`predict`](/api/pipe#predict) | Apply the component's model to a batch of [`Doc`](/api/doc) objects (without modifying them) and return the scores. |
| [`set_annotations`](/api/pipe#set_annotations) | Modify a batch of [`Doc`](/api/doc) objects, using pre-computed scores generated by `predict`. |
By default, [`Pipe.__init__`](/api/pipe#init) takes the shared vocab, the
[`Model`](https://thinc.ai/docs/api-model) and the name of the component
instance in the pipeline, which you can use as a key in the losses. All other
keyword arguments will become available as [`Pipe.cfg`](/api/pipe#cfg) and will
also be serialized with the component.
<Accordion title="Why components should be passed a Model instance, not create it" spaced>
spaCy's [config system](/usage/training#config) resolves the config describing
the pipeline components and models **bottom-up**. This means that it will
_first_ create a `Model` from a [registered architecture](/api/architectures),
validate its arguments and _then_ pass the object forward to the component. This
means that the config can express very complex, nested trees of objects but
the objects don't have to pass the model settings all the way down to the
components. It also makes the components more **modular** and lets you swap
different architectures in your config, and re-use model definitions.
```ini
### config.cfg (excerpt)
[components]
[components.textcat]
factory = "textcat"
labels = []
# This function is created and then passed to the "textcat" component as
# the argument "model"
[components.textcat.model]
@architectures = "spacy.TextCatEnsemble.v1"
exclusive_classes = false
pretrained_vectors = null
width = 64
conv_depth = 2
embed_size = 2000
window_size = 1
ngram_size = 1
dropout = null
[components.other_textcat]
factory = "textcat"
# This references the [components.textcat.model] block above
model = ${components.textcat.model}
labels = []
```
Your trainable pipeline component factories should therefore always take a
`model` argument instead of instantiating the
[`Model`](https://thinc.ai/docs/api-model) inside the component. To register
custom architectures, you can use the
[`@spacy.registry.architectures`](/api/top-level#registry) decorator. Also see
the [training guide](/usage/training#config) for details.
</Accordion>
For some use cases, it makes sense to also overwrite additional methods to
customize how the model is updated from examples, how it's initialized, how the
loss is calculated and to add evaluation scores to the training output.
| Name | Description |
| -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| [`update`](/api/pipe#update) | Learn from a batch of [`Example`](/api/example) objects containing the predictions and gold-standard annotations, and update the component's model. |
| [`begin_training`](/api/pipe#begin_training) | Initialize the model. Typically calls into [`Model.initialize`](https://thinc.ai/docs/api-model#initialize) and [`Pipe.create_optimizer`](/api/pipe#create_optimizer) if no optimizer is provided. |
| [`get_loss`](/api/pipe#get_loss) | Return a tuple of the loss and the gradient for a batch of [`Example`](/api/example) objects. |
| [`score`](/api/pipe#score) | Score a batch of [`Example`](/api/example) objects and return a dictionary of scores. The [`@Language.factory`](/api/language#factory) decorator can define the `default_socre_weights` of the component to decide which keys of the scores to display during training and how they count towards the final score. |
<!-- TODO: add more details, examples and maybe an example project -->
## Extension attributes {#custom-components-attributes new="2"} ## Extension attributes {#custom-components-attributes new="2"}