Mirror of https://github.com/explosion/spaCy.git (synced 2025-10-31 16:07:41 +03:00)
Tidy up pipes (#5906)

* Tidy up pipes
* Fix init, defaults and raise custom errors
* Update docs
* Update docs [ci skip]
* Apply suggestions from code review
* Tidy up error handling and validation, fix consistency
* Simplify get_examples check
* Remove unused import [ci skip]

Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
parent b7ec06e331
commit 950832f087
				|  | @ -295,7 +295,11 @@ def train_while_improving( | |||
|                 nlp.rehearse(raw_batch, sgd=optimizer, losses=losses, exclude=exclude) | ||||
|         # TODO: refactor this so we don't have to run it separately in here | ||||
|         for name, proc in nlp.pipeline: | ||||
|             if name not in exclude and hasattr(proc, "model"): | ||||
|             if ( | ||||
|                 name not in exclude | ||||
|                 and hasattr(proc, "model") | ||||
|                 and proc.model not in (True, False, None) | ||||
|             ): | ||||
|                 proc.model.finish_update(optimizer) | ||||
|         optimizer.step_schedules() | ||||
|         if not (step % eval_frequency): | ||||
|  |  | |||
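The widened condition above (also applied in `Language.update` further down) matters because components may hold a placeholder in `model` such as `True`, `False`, or `None` rather than a real Thinc model, and calling `finish_update` on a placeholder would crash the loop. A minimal runnable sketch of the guard, with made-up stand-in objects instead of real spaCy components:

```python
# Stand-ins for pipeline components; only "tagger" has a real model object.
class DummyModel:
    def finish_update(self, optimizer):
        print("finish_update called")

class Proc:
    def __init__(self, model=None, has_model=True):
        if has_model:
            self.model = model

pipeline = [
    ("tagger", Proc(model=DummyModel())),    # real model: gets updated
    ("sentencizer", Proc(has_model=False)),  # no .model attribute: skipped
    ("stub", Proc(model=True)),              # placeholder model: skipped
]
exclude = []
for name, proc in pipeline:
    if (
        name not in exclude
        and hasattr(proc, "model")
        and proc.model not in (True, False, None)
    ):
        proc.model.finish_update(optimizer=None)  # prints once, for "tagger"
```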
|  | @ -482,6 +482,15 @@ class Errors: | |||
|     E199 = ("Unable to merge 0-length span at doc[{start}:{end}].") | ||||
| 
 | ||||
|     # TODO: fix numbering after merging develop into master | ||||
|     E930 = ("Received invalid get_examples callback in {name}.begin_training. " | ||||
|             "Expected function that returns an iterable of Example objects but " | ||||
|             "got: {obj}") | ||||
|     E931 = ("Encountered Pipe subclass without Pipe.{method} method in component " | ||||
|             "'{name}'. If the component is trainable and you want to use this " | ||||
|             "method, make sure it's overwritten on the subclass. If your " | ||||
|             "component isn't trainable, add a method that does nothing or " | ||||
|             "don't use the Pipe base class.") | ||||
|     E940 = ("Found NaN values in scores.") | ||||
|     E941 = ("Can't find model '{name}'. It looks like you're trying to load a " | ||||
|             "model from a shortcut, which is deprecated as of spaCy v3.0. To " | ||||
|             "load the model, use its full name instead:\n\n" | ||||
|  | @ -578,8 +587,7 @@ class Errors: | |||
|             "but received None.") | ||||
|     E977 = ("Can not compare a MorphAnalysis with a string object. " | ||||
|             "This is likely a bug in spaCy, so feel free to open an issue.") | ||||
|     E978 = ("The '{method}' method of {name} takes a list of Example objects, " | ||||
|             "but found {types} instead.") | ||||
|     E978 = ("The {name} method takes a list of Example objects, but got: {types}") | ||||
|     E979 = ("Cannot convert {type} to an Example object.") | ||||
|     E980 = ("Each link annotation should refer to a dictionary with at most one " | ||||
|             "identifier mapping to 1.0, and all others to 0.0.") | ||||
|  |  | |||
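To show what users will actually see, the snippet below formats the new templates with plausible values; the template strings are copied verbatim from the `errors.py` hunk above, while the filled-in values are illustrative:

```python
# Render the new error templates with example values.
E930 = ("Received invalid get_examples callback in {name}.begin_training. "
        "Expected function that returns an iterable of Example objects but "
        "got: {obj}")
E978 = "The {name} method takes a list of Example objects, but got: {types}"

print(E930.format(name="Language", obj=type([])))        # passed a list, not a callable
print(E978.format(name="Language.update", types={str}))  # passed raw strings
```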
|  | @ -1,5 +1,5 @@ | |||
| from .corpus import Corpus  # noqa: F401 | ||||
| from .example import Example  # noqa: F401 | ||||
| from .example import Example, validate_examples  # noqa: F401 | ||||
| from .align import Alignment  # noqa: F401 | ||||
| from .iob_utils import iob_to_biluo, biluo_to_iob  # noqa: F401 | ||||
| from .iob_utils import biluo_tags_from_offsets, offsets_from_biluo_tags  # noqa: F401 | ||||
|  |  | |||
|  | @ -1,5 +1,5 @@ | |||
| from collections.abc import Iterable as IterableInstance | ||||
| import warnings | ||||
| 
 | ||||
| import numpy | ||||
| 
 | ||||
| from ..tokens.doc cimport Doc | ||||
|  | @ -26,6 +26,22 @@ cpdef Doc annotations2doc(vocab, tok_annot, doc_annot): | |||
|     return output | ||||
| 
 | ||||
| 
 | ||||
| def validate_examples(examples, method): | ||||
|     """Check that a batch of examples received during processing is valid. | ||||
|     This function lives here to prevent circular imports. | ||||
| 
 | ||||
|     examples (Iterable[Example]): A batch of examples. | ||||
|     method (str): The method name to show in error messages. | ||||
|     """ | ||||
|     if not isinstance(examples, IterableInstance): | ||||
|         err = Errors.E978.format(name=method, types=type(examples)) | ||||
|         raise TypeError(err) | ||||
|     wrong = set([type(eg) for eg in examples if not isinstance(eg, Example)]) | ||||
|     if wrong: | ||||
|         err = Errors.E978.format(name=method, types=wrong) | ||||
|         raise TypeError(err) | ||||
| 
 | ||||
| 
 | ||||
| cdef class Example: | ||||
|     def __init__(self, Doc predicted, Doc reference, *, alignment=None): | ||||
|         if predicted is None: | ||||
|  | @ -263,12 +279,10 @@ def _annot2array(vocab, tok_annot, doc_annot): | |||
|             values.append([vocab.morphology.add(v) for v in value]) | ||||
|         else: | ||||
|             attrs.append(key) | ||||
|             try: | ||||
|                 values.append([vocab.strings.add(v) for v in value]) | ||||
|             except TypeError: | ||||
|                 types= set([type(v) for v in value]) | ||||
|             if not all(isinstance(v, str) for v in value): | ||||
|                 types = set([type(v) for v in value]) | ||||
|                 raise TypeError(Errors.E969.format(field=key, types=types)) from None | ||||
| 
 | ||||
|             values.append([vocab.strings.add(v) for v in value]) | ||||
|     array = numpy.asarray(values, dtype="uint64") | ||||
|     return attrs, array.T | ||||
| 
 | ||||
|  |  | |||
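A pure-Python sketch of what the new helper does, with a stand-in `Example` class so it runs outside spaCy (the real helper lives in `example.pyx` to avoid circular imports):

```python
from collections.abc import Iterable

class Example:  # stand-in for spacy.gold.Example
    pass

def validate_examples(examples, method):
    # Reject non-iterables outright, then collect any wrong member types.
    if not isinstance(examples, Iterable):
        raise TypeError(f"The {method} method takes a list of Example objects, "
                        f"but got: {type(examples)}")
    wrong = {type(eg) for eg in examples if not isinstance(eg, Example)}
    if wrong:
        raise TypeError(f"The {method} method takes a list of Example objects, "
                        f"but got: {wrong}")

validate_examples([Example(), Example()], "Tagger.update")  # passes silently
try:
    validate_examples(["raw text"], "Tagger.update")
except TypeError as err:
    print(err)  # names the offending types: {<class 'str'>}
```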
|  | @ -5,7 +5,6 @@ import random | |||
| import itertools | ||||
| import weakref | ||||
| import functools | ||||
| from collections import Iterable as IterableInstance | ||||
| from contextlib import contextmanager | ||||
| from copy import copy, deepcopy | ||||
| from pathlib import Path | ||||
|  | @ -19,7 +18,7 @@ from timeit import default_timer as timer | |||
| from .tokens.underscore import Underscore | ||||
| from .vocab import Vocab, create_vocab | ||||
| from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis | ||||
| from .gold import Example | ||||
| from .gold import Example, validate_examples | ||||
| from .scorer import Scorer | ||||
| from .util import create_default_optimizer, registry | ||||
| from .util import SimpleFrozenDict, combine_score_weights | ||||
|  | @ -935,17 +934,7 @@ class Language: | |||
|             losses = {} | ||||
|         if len(examples) == 0: | ||||
|             return losses | ||||
|         if not isinstance(examples, IterableInstance): | ||||
|             raise TypeError( | ||||
|                 Errors.E978.format( | ||||
|                     name="language", method="update", types=type(examples) | ||||
|                 ) | ||||
|             ) | ||||
|         wrong_types = set([type(eg) for eg in examples if not isinstance(eg, Example)]) | ||||
|         if wrong_types: | ||||
|             raise TypeError( | ||||
|                 Errors.E978.format(name="language", method="update", types=wrong_types) | ||||
|             ) | ||||
|         validate_examples(examples, "Language.update") | ||||
|         if sgd is None: | ||||
|             if self._optimizer is None: | ||||
|                 self._optimizer = create_default_optimizer() | ||||
|  | @ -962,7 +951,11 @@ class Language: | |||
|             proc.update(examples, sgd=None, losses=losses, **component_cfg[name]) | ||||
|         if sgd not in (None, False): | ||||
|             for name, proc in self.pipeline: | ||||
|                 if name not in exclude and hasattr(proc, "model"): | ||||
|                 if ( | ||||
|                     name not in exclude | ||||
|                     and hasattr(proc, "model") | ||||
|                     and proc.model not in (True, False, None) | ||||
|                 ): | ||||
|                     proc.model.finish_update(sgd) | ||||
|         return losses | ||||
| 
 | ||||
|  | @ -999,19 +992,7 @@ class Language: | |||
|         """ | ||||
|         if len(examples) == 0: | ||||
|             return | ||||
|         if not isinstance(examples, IterableInstance): | ||||
|             raise TypeError( | ||||
|                 Errors.E978.format( | ||||
|                     name="language", method="rehearse", types=type(examples) | ||||
|                 ) | ||||
|             ) | ||||
|         wrong_types = set([type(eg) for eg in examples if not isinstance(eg, Example)]) | ||||
|         if wrong_types: | ||||
|             raise TypeError( | ||||
|                 Errors.E978.format( | ||||
|                     name="language", method="rehearse", types=wrong_types | ||||
|                 ) | ||||
|             ) | ||||
|         validate_examples(examples, "Language.rehearse") | ||||
|         if sgd is None: | ||||
|             if self._optimizer is None: | ||||
|                 self._optimizer = create_default_optimizer() | ||||
|  | @ -1060,7 +1041,15 @@ class Language: | |||
|         if get_examples is None: | ||||
|             get_examples = lambda: [] | ||||
|         else:  # Populate vocab | ||||
|             if not hasattr(get_examples, "__call__"): | ||||
|                 err = Errors.E930.format(name="Language", obj=type(get_examples)) | ||||
|                 raise ValueError(err) | ||||
|             for example in get_examples(): | ||||
|                 if not isinstance(example, Example): | ||||
|                     err = Errors.E978.format( | ||||
|                         name="Language.begin_training", types=type(example) | ||||
|                     ) | ||||
|                     raise ValueError(err) | ||||
|                 for word in [t.text for t in example.reference]: | ||||
|                     _ = self.vocab[word]  # noqa: F841 | ||||
|         if device >= 0:  # TODO: do we need this here? | ||||
|  | @ -1133,17 +1122,7 @@ class Language: | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/language#evaluate | ||||
|         """ | ||||
|         if not isinstance(examples, IterableInstance): | ||||
|             err = Errors.E978.format( | ||||
|                 name="language", method="evaluate", types=type(examples) | ||||
|             ) | ||||
|             raise TypeError(err) | ||||
|         wrong_types = set([type(eg) for eg in examples if not isinstance(eg, Example)]) | ||||
|         if wrong_types: | ||||
|             err = Errors.E978.format( | ||||
|                 name="language", method="evaluate", types=wrong_types | ||||
|             ) | ||||
|             raise TypeError(err) | ||||
|         validate_examples(examples, "Language.evaluate") | ||||
|         if component_cfg is None: | ||||
|             component_cfg = {} | ||||
|         if scorer_cfg is None: | ||||
|  | @ -1663,7 +1642,7 @@ def _fix_pretrained_vectors_name(nlp: Language) -> None: | |||
|     else: | ||||
|         raise ValueError(Errors.E092) | ||||
|     for name, proc in nlp.pipeline: | ||||
|         if not hasattr(proc, "cfg"): | ||||
|         if not hasattr(proc, "cfg") or not isinstance(proc.cfg, dict): | ||||
|             continue | ||||
|         proc.cfg.setdefault("deprecation_fixes", {}) | ||||
|         proc.cfg["deprecation_fixes"]["vectors_name"] = nlp.vocab.vectors.name | ||||
|  |  | |||
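Taken together, the `language.py` changes give a simple contract: `begin_training` takes a zero-argument callable that yields `Example` objects, while `update`, `rehearse`, and `evaluate` take iterables of `Example`s and fail fast otherwise. A usage sketch, assuming a spaCy v3 nightly that includes this commit:

```python
import spacy
from spacy.gold import Example  # renamed to spacy.training in later releases

nlp = spacy.blank("en")
nlp.add_pipe("tagger")
doc = nlp.make_doc("hello world")
examples = [Example.from_dict(doc, {"tags": ["INTJ", "NOUN"]})]

optimizer = nlp.begin_training(lambda: examples)  # a callable, not the raw list
losses = nlp.update(examples, sgd=optimizer)

# nlp.begin_training(examples) would now raise ValueError (E930)
# nlp.update(["some text"]) would now raise TypeError (E978)
```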
|  | @ -9,6 +9,7 @@ from .functions import merge_subtokens | |||
| from ..language import Language | ||||
| from ._parser_internals import nonproj | ||||
| from ..scorer import Scorer | ||||
| from ..gold import validate_examples | ||||
| 
 | ||||
| 
 | ||||
| default_model_config = """ | ||||
|  | @ -147,6 +148,7 @@ cdef class DependencyParser(Parser): | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/dependencyparser#score | ||||
|         """ | ||||
|         validate_examples(examples, "DependencyParser.score") | ||||
|         def dep_getter(token, attr): | ||||
|             dep = getattr(token, attr) | ||||
|             dep = token.vocab.strings.as_string(dep).lower() | ||||
|  |  | |||
|  | @ -11,7 +11,7 @@ from ..tokens import Doc | |||
| from .pipe import Pipe, deserialize_config | ||||
| from ..language import Language | ||||
| from ..vocab import Vocab | ||||
| from ..gold import Example | ||||
| from ..gold import Example, validate_examples | ||||
| from ..errors import Errors, Warnings | ||||
| from .. import util | ||||
| 
 | ||||
|  | @ -142,7 +142,7 @@ class EntityLinker(Pipe): | |||
| 
 | ||||
|     def begin_training( | ||||
|         self, | ||||
|         get_examples: Callable[[], Iterable[Example]] = lambda: [], | ||||
|         get_examples: Callable[[], Iterable[Example]], | ||||
|         *, | ||||
|         pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None, | ||||
|         sgd: Optional[Optimizer] = None, | ||||
|  | @ -197,14 +197,9 @@ class EntityLinker(Pipe): | |||
|         losses.setdefault(self.name, 0.0) | ||||
|         if not examples: | ||||
|             return losses | ||||
|         validate_examples(examples, "EntityLinker.update") | ||||
|         sentence_docs = [] | ||||
|         try: | ||||
|             docs = [eg.predicted for eg in examples] | ||||
|         except AttributeError: | ||||
|             types = set([type(eg) for eg in examples]) | ||||
|             raise TypeError( | ||||
|                 Errors.E978.format(name="EntityLinker", method="update", types=types) | ||||
|             ) from None | ||||
|         docs = [eg.predicted for eg in examples] | ||||
|         if set_annotations: | ||||
|             # This seems simpler than other ways to get that exact output -- but | ||||
|             # it does run the model twice :( | ||||
|  | @ -250,6 +245,7 @@ class EntityLinker(Pipe): | |||
|         return losses | ||||
| 
 | ||||
|     def get_loss(self, examples: Iterable[Example], sentence_encodings): | ||||
|         validate_examples(examples, "EntityLinker.get_loss") | ||||
|         entity_encodings = [] | ||||
|         for eg in examples: | ||||
|             kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True) | ||||
|  |  | |||
|  | @ -9,6 +9,7 @@ from ..util import ensure_path, to_disk, from_disk | |||
| from ..tokens import Doc, Span | ||||
| from ..matcher import Matcher, PhraseMatcher | ||||
| from ..scorer import Scorer | ||||
| from ..gold import validate_examples | ||||
| 
 | ||||
| 
 | ||||
| DEFAULT_ENT_ID_SEP = "||" | ||||
|  | @ -312,6 +313,7 @@ class EntityRuler: | |||
|         return label | ||||
| 
 | ||||
|     def score(self, examples, **kwargs): | ||||
|         validate_examples(examples, "EntityRuler.score") | ||||
|         return Scorer.score_spans(examples, "ents", **kwargs) | ||||
| 
 | ||||
|     def from_bytes( | ||||
|  |  | |||
|  | @ -1,5 +1,4 @@ | |||
| from typing import Optional, List, Dict, Any | ||||
| 
 | ||||
| from thinc.api import Model | ||||
| 
 | ||||
| from .pipe import Pipe | ||||
|  | @ -9,6 +8,7 @@ from ..lookups import Lookups, load_lookups | |||
| from ..scorer import Scorer | ||||
| from ..tokens import Doc, Token | ||||
| from ..vocab import Vocab | ||||
| from ..gold import validate_examples | ||||
| from .. import util | ||||
| 
 | ||||
| 
 | ||||
|  | @ -135,10 +135,10 @@ class Lemmatizer(Pipe): | |||
|         elif self.mode == "rule": | ||||
|             self.lemmatize = self.rule_lemmatize | ||||
|         else: | ||||
|             try: | ||||
|                 self.lemmatize = getattr(self, f"{self.mode}_lemmatize") | ||||
|             except AttributeError: | ||||
|             mode_attr = f"{self.mode}_lemmatize" | ||||
|             if not hasattr(self, mode_attr): | ||||
|                 raise ValueError(Errors.E1003.format(mode=mode)) | ||||
|             self.lemmatize = getattr(self, mode_attr) | ||||
|         self.cache = {} | ||||
| 
 | ||||
|     @property | ||||
|  | @ -271,6 +271,7 @@ class Lemmatizer(Pipe): | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/lemmatizer#score | ||||
|         """ | ||||
|         validate_examples(examples, "Lemmatizer.score") | ||||
|         return Scorer.score_token_attr(examples, "lemma", **kwargs) | ||||
| 
 | ||||
|     def to_disk(self, path, *, exclude=tuple()): | ||||
|  |  | |||
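The mode lookup above now fails fast with an explicit `hasattr` check; the old `try/except AttributeError` around `getattr` could also swallow an `AttributeError` raised from inside a property or a custom mode's own code. A sketch of the dispatch with a hypothetical custom mode:

```python
class MiniLemmatizer:
    # Resolve "<mode>_lemmatize" explicitly instead of catching AttributeError.
    def __init__(self, mode):
        mode_attr = f"{mode}_lemmatize"
        if not hasattr(self, mode_attr):
            raise ValueError(f"[E1003] Unknown lemmatizer mode: {mode}")
        self.lemmatize = getattr(self, mode_attr)

    def fancy_lemmatize(self, word):  # hypothetical custom mode
        return word.rstrip("s")

print(MiniLemmatizer("fancy").lemmatize("cats"))  # "cat"
# MiniLemmatizer("nope") raises ValueError, not a confusing AttributeError
```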
|  | @ -6,15 +6,16 @@ from thinc.api import SequenceCategoricalCrossentropy, Model, Config | |||
| from ..tokens.doc cimport Doc | ||||
| from ..vocab cimport Vocab | ||||
| from ..morphology cimport Morphology | ||||
| 
 | ||||
| from ..parts_of_speech import IDS as POS_IDS | ||||
| from ..symbols import POS | ||||
| 
 | ||||
| from ..language import Language | ||||
| from ..errors import Errors | ||||
| from .pipe import deserialize_config | ||||
| from .tagger import Tagger | ||||
| from .. import util | ||||
| from ..scorer import Scorer | ||||
| from ..gold import validate_examples | ||||
| 
 | ||||
| 
 | ||||
| default_model_config = """ | ||||
|  | @ -126,7 +127,7 @@ class Morphologizer(Tagger): | |||
|             self.cfg["labels_pos"][norm_label] = POS_IDS[pos] | ||||
|         return 1 | ||||
| 
 | ||||
|     def begin_training(self, get_examples=lambda: [], *, pipeline=None, sgd=None): | ||||
|     def begin_training(self, get_examples, *, pipeline=None, sgd=None): | ||||
|         """Initialize the pipe for training, using data examples if available. | ||||
| 
 | ||||
|         get_examples (Callable[[], Iterable[Example]]): Optional function that | ||||
|  | @ -140,6 +141,9 @@ class Morphologizer(Tagger): | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/morphologizer#begin_training | ||||
|         """ | ||||
|         if not hasattr(get_examples, "__call__"): | ||||
|             err = Errors.E930.format(name="Morphologizer", obj=type(get_examples)) | ||||
|             raise ValueError(err) | ||||
|         for example in get_examples(): | ||||
|             for i, token in enumerate(example.reference): | ||||
|                 pos = token.pos_ | ||||
|  | @ -192,6 +196,7 @@ class Morphologizer(Tagger): | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/morphologizer#get_loss | ||||
|         """ | ||||
|         validate_examples(examples, "Morphologizer.get_loss") | ||||
|         loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False) | ||||
|         truths = [] | ||||
|         for eg in examples: | ||||
|  | @ -228,6 +233,7 @@ class Morphologizer(Tagger): | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/morphologizer#score | ||||
|         """ | ||||
|         validate_examples(examples, "Morphologizer.score") | ||||
|         results = {} | ||||
|         results.update(Scorer.score_token_attr(examples, "pos", **kwargs)) | ||||
|         results.update(Scorer.score_token_attr(examples, "morph", **kwargs)) | ||||
|  |  | |||
|  | @ -8,6 +8,7 @@ from ..tokens.doc cimport Doc | |||
| 
 | ||||
| from .pipe import Pipe | ||||
| from .tagger import Tagger | ||||
| from ..gold import validate_examples | ||||
| from ..language import Language | ||||
| from ._parser_internals import nonproj | ||||
| from ..attrs import POS, ID | ||||
|  | @ -80,10 +81,11 @@ class MultitaskObjective(Tagger): | |||
|     def set_annotations(self, docs, dep_ids): | ||||
|         pass | ||||
| 
 | ||||
|     def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None): | ||||
|         gold_examples = nonproj.preprocess_training_data(get_examples()) | ||||
|         # for raw_text, doc_annot in gold_tuples: | ||||
|         for example in gold_examples: | ||||
|     def begin_training(self, get_examples, pipeline=None, sgd=None): | ||||
|         if not hasattr(get_examples, "__call__"): | ||||
|             err = Errors.E930.format(name="MultitaskObjective", obj=type(get_examples)) | ||||
|             raise ValueError(err) | ||||
|         for example in get_examples(): | ||||
|             for token in example.y: | ||||
|                 label = self.make_label(token) | ||||
|                 if label is not None and label not in self.labels: | ||||
|  | @ -175,7 +177,7 @@ class ClozeMultitask(Pipe): | |||
|     def set_annotations(self, docs, dep_ids): | ||||
|         pass | ||||
| 
 | ||||
|     def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None): | ||||
|     def begin_training(self, get_examples, pipeline=None, sgd=None): | ||||
|         self.model.initialize() | ||||
|         X = self.model.ops.alloc((5, self.model.get_ref("tok2vec").get_dim("nO"))) | ||||
|         self.model.output_layer.begin_training(X) | ||||
|  | @ -189,6 +191,7 @@ class ClozeMultitask(Pipe): | |||
|         return tokvecs, vectors | ||||
| 
 | ||||
|     def get_loss(self, examples, vectors, prediction): | ||||
|         validate_examples(examples, "ClozeMultitask.get_loss") | ||||
|         # The simplest way to implement this would be to vstack the | ||||
|         # token.vector values, but that's a bit inefficient, especially on GPU. | ||||
|         # Instead we fetch the index into the vectors table for each of our tokens, | ||||
|  | @ -206,18 +209,16 @@ class ClozeMultitask(Pipe): | |||
|         if losses is not None and self.name not in losses: | ||||
|             losses[self.name] = 0. | ||||
|         set_dropout_rate(self.model, drop) | ||||
|         try: | ||||
|             predictions, bp_predictions = self.model.begin_update([eg.predicted for eg in examples]) | ||||
|         except AttributeError: | ||||
|             types = set([type(eg) for eg in examples]) | ||||
|             raise TypeError(Errors.E978.format(name="ClozeMultitask", method="rehearse", types=types)) from None | ||||
|         validate_examples(examples, "ClozeMultitask.rehearse") | ||||
|         docs = [eg.predicted for eg in examples] | ||||
|         predictions, bp_predictions = self.model.begin_update(docs) | ||||
|         loss, d_predictions = self.get_loss(examples, self.vocab.vectors.data, predictions) | ||||
|         bp_predictions(d_predictions) | ||||
|         if sgd is not None: | ||||
|             self.model.finish_update(sgd) | ||||
| 
 | ||||
|         if losses is not None: | ||||
|             losses[self.name] += loss | ||||
|         return losses | ||||
| 
 | ||||
|     def add_label(self, label): | ||||
|         raise NotImplementedError | ||||
|  |  | |||
|  | @ -7,6 +7,7 @@ from ._parser_internals.ner cimport BiluoPushDown | |||
| 
 | ||||
| from ..language import Language | ||||
| from ..scorer import Scorer | ||||
| from ..gold import validate_examples | ||||
| 
 | ||||
| 
 | ||||
| default_model_config = """ | ||||
|  | @ -50,7 +51,7 @@ def make_ner( | |||
| ): | ||||
|     """Create a transition-based EntityRecognizer component. The entity recognizer | ||||
|     identifies non-overlapping labelled spans of tokens. | ||||
|      | ||||
| 
 | ||||
|     The transition-based algorithm used encodes certain assumptions that are | ||||
|     effective for "traditional" named entity recognition tasks, but may not be | ||||
|     a good fit for every span identification problem. Specifically, the loss | ||||
|  | @ -120,4 +121,5 @@ cdef class EntityRecognizer(Parser): | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/entityrecognizer#score | ||||
|         """ | ||||
|         validate_examples(examples, "EntityRecognizer.score") | ||||
|         return Scorer.score_spans(examples, "ents", **kwargs) | ||||
|  |  | |||
|  | @ -1,2 +1,5 @@ | |||
| cdef class Pipe: | ||||
|     cdef public object vocab | ||||
|     cdef public object model | ||||
|     cdef public str name | ||||
|     cdef public object cfg | ||||
|  |  | |||
|  | @ -1,9 +1,10 @@ | |||
| # cython: infer_types=True, profile=True | ||||
| import srsly | ||||
| from thinc.api import set_dropout_rate, Model | ||||
| 
 | ||||
| from ..tokens.doc cimport Doc | ||||
| 
 | ||||
| from ..util import create_default_optimizer | ||||
| from ..gold import validate_examples | ||||
| from ..errors import Errors | ||||
| from .. import util | ||||
| 
 | ||||
|  | @ -16,7 +17,6 @@ cdef class Pipe: | |||
| 
 | ||||
|     DOCS: https://spacy.io/api/pipe | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, vocab, model, name, **cfg): | ||||
|         """Initialize a pipeline component. | ||||
| 
 | ||||
|  | @ -27,7 +27,10 @@ cdef class Pipe: | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/pipe#init | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|         self.vocab = vocab | ||||
|         self.model = model | ||||
|         self.name = name | ||||
|         self.cfg = dict(cfg) | ||||
| 
 | ||||
|     def __call__(self, Doc doc): | ||||
|         """Apply the pipe to one document. The document is modified in place, | ||||
|  | @ -68,7 +71,7 @@ cdef class Pipe: | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/pipe#predict | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|         raise NotImplementedError(Errors.E931.format(method="predict", name=self.name)) | ||||
| 
 | ||||
|     def set_annotations(self, docs, scores): | ||||
|         """Modify a batch of documents, using pre-computed scores. | ||||
|  | @ -78,7 +81,43 @@ cdef class Pipe: | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/pipe#set_annotations | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|         raise NotImplementedError(Errors.E931.format(method="set_annotations", name=self.name)) | ||||
| 
 | ||||
|     def update(self, examples, *, drop=0.0, set_annotations=False, sgd=None, losses=None): | ||||
|         """Learn from a batch of documents and gold-standard information, | ||||
|         updating the pipe's model. Delegates to predict and get_loss. | ||||
| 
 | ||||
|         examples (Iterable[Example]): A batch of Example objects. | ||||
|         drop (float): The dropout rate. | ||||
|         set_annotations (bool): Whether or not to update the Example objects | ||||
|             with the predictions. | ||||
|         sgd (thinc.api.Optimizer): The optimizer. | ||||
|         losses (Dict[str, float]): Optional record of the loss during training. | ||||
|             Updated using the component name as the key. | ||||
|         RETURNS (Dict[str, float]): The updated losses dictionary. | ||||
| 
 | ||||
|         DOCS: https://spacy.io/api/pipe#update | ||||
|         """ | ||||
|         if losses is None: | ||||
|             losses = {} | ||||
|         if not hasattr(self, "model") or self.model in (None, True, False): | ||||
|             return losses | ||||
|         losses.setdefault(self.name, 0.0) | ||||
|         validate_examples(examples, "Pipe.update") | ||||
|         if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples): | ||||
|             # Handle cases where there are no tokens in any docs. | ||||
|             return losses | ||||
|         set_dropout_rate(self.model, drop) | ||||
|         scores, bp_scores = self.model.begin_update([eg.predicted for eg in examples]) | ||||
|         loss, d_scores = self.get_loss(examples, scores) | ||||
|         bp_scores(d_scores) | ||||
|         if sgd not in (None, False): | ||||
|             self.model.finish_update(sgd) | ||||
|         losses[self.name] += loss | ||||
|         if set_annotations: | ||||
|             docs = [eg.predicted for eg in examples] | ||||
|             self.set_annotations(docs, scores=scores) | ||||
|         return losses | ||||
| 
 | ||||
|     def rehearse(self, examples, *, sgd=None, losses=None, **config): | ||||
|         """Perform a "rehearsal" update from a batch of data. Rehearsal updates | ||||
|  | @ -107,7 +146,7 @@ cdef class Pipe: | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/pipe#get_loss | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|         raise NotImplementedError(Errors.E931.format(method="get_loss", name=self.name)) | ||||
| 
 | ||||
|     def add_label(self, label): | ||||
|         """Add an output label, to be predicted by the model. It's possible to | ||||
|  | @ -119,7 +158,7 @@ cdef class Pipe: | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/pipe#add_label | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|         raise NotImplementedError(Errors.E931.format(method="add_label", name=self.name)) | ||||
| 
 | ||||
|     def create_optimizer(self): | ||||
|         """Create an optimizer for the pipeline component. | ||||
|  | @ -128,9 +167,9 @@ cdef class Pipe: | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/pipe#create_optimizer | ||||
|         """ | ||||
|         return create_default_optimizer() | ||||
|         return util.create_default_optimizer() | ||||
| 
 | ||||
|     def begin_training(self, get_examples=lambda: [], *, pipeline=None, sgd=None): | ||||
|     def begin_training(self, get_examples, *, pipeline=None, sgd=None): | ||||
|         """Initialize the pipe for training, using data examples if available. | ||||
| 
 | ||||
|         get_examples (Callable[[], Iterable[Example]]): Optional function that | ||||
|  |  | |||
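With `Pipe.__init__` storing `vocab`, `model`, `name`, and `cfg` and `Pipe.update` gaining a shared default implementation, a trainable subclass mainly has to supply `predict`, `set_annotations`, and `get_loss`; anything left unimplemented now raises `NotImplementedError` with an E931 message naming the component. A pure-Python sketch of that contract (not the real `cdef` class):

```python
class Pipe:
    # Base class stores its construction arguments; stub methods name the
    # offending component in their error, mirroring E931.
    def __init__(self, vocab, model, name, **cfg):
        self.vocab = vocab
        self.model = model
        self.name = name
        self.cfg = dict(cfg)

    def predict(self, docs):
        raise NotImplementedError(
            f"[E931] Pipe subclass without Pipe.predict method in component '{self.name}'"
        )

class GoodPipe(Pipe):
    def predict(self, docs):
        return [0.0] * len(docs)  # dummy scores

print(GoodPipe(None, None, "good").predict(["doc1", "doc2"]))
try:
    Pipe(None, None, "bad").predict(["doc1"])
except NotImplementedError as err:
    print(err)  # says exactly which component lacks the method
```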
|  | @ -7,6 +7,7 @@ from ..tokens.doc cimport Doc | |||
| from .pipe import Pipe | ||||
| from ..language import Language | ||||
| from ..scorer import Scorer | ||||
| from ..gold import validate_examples | ||||
| from .. import util | ||||
| 
 | ||||
| 
 | ||||
|  | @ -58,7 +59,7 @@ class Sentencizer(Pipe): | |||
|         else: | ||||
|             self.punct_chars = set(self.default_punct_chars) | ||||
| 
 | ||||
|     def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None): | ||||
|     def begin_training(self, get_examples, pipeline=None, sgd=None): | ||||
|         pass | ||||
| 
 | ||||
|     def __call__(self, doc): | ||||
|  | @ -158,6 +159,7 @@ class Sentencizer(Pipe): | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/sentencizer#score | ||||
|         """ | ||||
|         validate_examples(examples, "Sentencizer.score") | ||||
|         results = Scorer.score_spans(examples, "sents", **kwargs) | ||||
|         del results["sents_per_type"] | ||||
|         return results | ||||
|  |  | |||
|  | @ -9,6 +9,7 @@ from .tagger import Tagger | |||
| from ..language import Language | ||||
| from ..errors import Errors | ||||
| from ..scorer import Scorer | ||||
| from ..gold import validate_examples | ||||
| from .. import util | ||||
| 
 | ||||
| 
 | ||||
|  | @ -102,6 +103,7 @@ class SentenceRecognizer(Tagger): | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/sentencerecognizer#get_loss | ||||
|         """ | ||||
|         validate_examples(examples, "SentenceRecognizer.get_loss") | ||||
|         labels = self.labels | ||||
|         loss_func = SequenceCategoricalCrossentropy(names=labels, normalize=False) | ||||
|         truths = [] | ||||
|  | @ -121,7 +123,7 @@ class SentenceRecognizer(Tagger): | |||
|             raise ValueError("nan value when computing loss") | ||||
|         return float(loss), d_scores | ||||
| 
 | ||||
|     def begin_training(self, get_examples=lambda: [], *, pipeline=None, sgd=None): | ||||
|     def begin_training(self, get_examples, *, pipeline=None, sgd=None): | ||||
|         """Initialize the pipe for training, using data examples if available. | ||||
| 
 | ||||
|         get_examples (Callable[[], Iterable[Example]]): Optional function that | ||||
|  | @ -151,6 +153,7 @@ class SentenceRecognizer(Tagger): | |||
|         RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans. | ||||
|         DOCS: https://spacy.io/api/sentencerecognizer#score | ||||
|         """ | ||||
|         validate_examples(examples, "SentenceRecognizer.score") | ||||
|         results = Scorer.score_spans(examples, "sents", **kwargs) | ||||
|         del results["sents_per_type"] | ||||
|         return results | ||||
|  |  | |||
|  | @ -1,4 +1,4 @@ | |||
| from typing import List, Iterable, Optional, Dict, Tuple, Callable | ||||
| from typing import List, Iterable, Optional, Dict, Tuple, Callable, Set | ||||
| from thinc.types import Floats2d | ||||
| from thinc.api import SequenceCategoricalCrossentropy, set_dropout_rate, Model | ||||
| from thinc.api import Optimizer, Config | ||||
|  | @ -6,6 +6,7 @@ from thinc.util import to_numpy | |||
| 
 | ||||
| from ..errors import Errors | ||||
| from ..gold import Example, spans_from_biluo_tags, iob_to_biluo, biluo_to_iob | ||||
| from ..gold import validate_examples | ||||
| from ..tokens import Doc | ||||
| from ..language import Language | ||||
| from ..vocab import Vocab | ||||
|  | @ -127,6 +128,7 @@ class SimpleNER(Pipe): | |||
|         if losses is None: | ||||
|             losses = {} | ||||
|         losses.setdefault("ner", 0.0) | ||||
|         validate_examples(examples, "SimpleNER.update") | ||||
|         if not any(_has_ner(eg) for eg in examples): | ||||
|             return losses | ||||
|         docs = [eg.predicted for eg in examples] | ||||
|  | @ -142,6 +144,7 @@ class SimpleNER(Pipe): | |||
|         return losses | ||||
| 
 | ||||
|     def get_loss(self, examples: List[Example], scores) -> Tuple[List[Floats2d], float]: | ||||
|         validate_examples(examples, "SimpleNER.get_loss") | ||||
|         truths = [] | ||||
|         for eg in examples: | ||||
|             tags = eg.get_aligned_ner() | ||||
|  | @ -161,14 +164,17 @@ class SimpleNER(Pipe): | |||
| 
 | ||||
|     def begin_training( | ||||
|         self, | ||||
|         get_examples: Callable, | ||||
|         get_examples: Callable[[], Iterable[Example]], | ||||
|         pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None, | ||||
|         sgd: Optional[Optimizer] = None, | ||||
|     ): | ||||
|         all_labels = set() | ||||
|         if not hasattr(get_examples, "__call__"): | ||||
|             gold_tuples = get_examples | ||||
|             get_examples = lambda: gold_tuples | ||||
|         for label in _get_labels(get_examples()): | ||||
|             err = Errors.E930.format(name="SimpleNER", obj=type(get_examples)) | ||||
|             raise ValueError(err) | ||||
|         for example in get_examples(): | ||||
|             all_labels.update(_get_labels(example)) | ||||
|         for label in sorted(all_labels): | ||||
|             self.add_label(label) | ||||
|         labels = self.labels | ||||
|         n_actions = self.model.attrs["get_num_actions"](len(labels)) | ||||
|  | @ -185,6 +191,7 @@ class SimpleNER(Pipe): | |||
|         pass | ||||
| 
 | ||||
|     def score(self, examples, **kwargs): | ||||
|         validate_examples(examples, "SimpleNER.score") | ||||
|         return Scorer.score_spans(examples, "ents", **kwargs) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -196,10 +203,9 @@ def _has_ner(example: Example) -> bool: | |||
|         return False | ||||
| 
 | ||||
| 
 | ||||
| def _get_labels(examples: List[Example]) -> List[str]: | ||||
| def _get_labels(example: Example) -> Set[str]: | ||||
|     labels = set() | ||||
|     for eg in examples: | ||||
|         for ner_tag in eg.get_aligned("ENT_TYPE", as_string=True): | ||||
|             if ner_tag != "O" and ner_tag != "-": | ||||
|                 labels.add(ner_tag) | ||||
|     return list(sorted(labels)) | ||||
|     for ner_tag in example.get_aligned("ENT_TYPE", as_string=True): | ||||
|         if ner_tag != "O" and ner_tag != "-": | ||||
|             labels.add(ner_tag) | ||||
|     return labels | ||||
|  |  | |||
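`_get_labels` now works on a single example, so `begin_training` can stream the corpus once, accumulate labels into one set, and sort once at the end instead of sorting inside the helper. A sketch with plain tag lists standing in for the aligned `ENT_TYPE` values:

```python
def get_labels(ner_tags):
    # Per-example extraction: skip "O" (outside) and "-" (missing annotation).
    return {tag for tag in ner_tags if tag not in ("O", "-")}

all_labels = set()
for tags in (["O", "PERSON", "PERSON", "-"], ["ORG", "O"]):
    all_labels.update(get_labels(tags))
print(sorted(all_labels))  # ['ORG', 'PERSON']
```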
|  | @ -16,6 +16,7 @@ from ..attrs import POS, ID | |||
| from ..parts_of_speech import X | ||||
| from ..errors import Errors, TempErrors, Warnings | ||||
| from ..scorer import Scorer | ||||
| from ..gold import validate_examples | ||||
| from .. import util | ||||
| 
 | ||||
| 
 | ||||
|  | @ -187,19 +188,15 @@ class Tagger(Pipe): | |||
|         if losses is None: | ||||
|             losses = {} | ||||
|         losses.setdefault(self.name, 0.0) | ||||
|         try: | ||||
|             if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples): | ||||
|                 # Handle cases where there are no tokens in any docs. | ||||
|                 return | ||||
|         except AttributeError: | ||||
|             types = set([type(eg) for eg in examples]) | ||||
|             raise TypeError(Errors.E978.format(name="Tagger", method="update", types=types)) from None | ||||
|         validate_examples(examples, "Tagger.update") | ||||
|         if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples): | ||||
|             # Handle cases where there are no tokens in any docs. | ||||
|             return losses | ||||
|         set_dropout_rate(self.model, drop) | ||||
|         tag_scores, bp_tag_scores = self.model.begin_update( | ||||
|             [eg.predicted for eg in examples]) | ||||
|         tag_scores, bp_tag_scores = self.model.begin_update([eg.predicted for eg in examples]) | ||||
|         for sc in tag_scores: | ||||
|             if self.model.ops.xp.isnan(sc.sum()): | ||||
|                 raise ValueError("nan value in scores") | ||||
|                 raise ValueError(Errors.E940) | ||||
|         loss, d_tag_scores = self.get_loss(examples, tag_scores) | ||||
|         bp_tag_scores(d_tag_scores) | ||||
|         if sgd not in (None, False): | ||||
|  | @ -226,11 +223,8 @@ class Tagger(Pipe): | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/tagger#rehearse | ||||
|         """ | ||||
|         try: | ||||
|             docs = [eg.predicted for eg in examples] | ||||
|         except AttributeError: | ||||
|             types = set([type(eg) for eg in examples]) | ||||
|             raise TypeError(Errors.E978.format(name="Tagger", method="rehearse", types=types)) from None | ||||
|         validate_examples(examples, "Tagger.rehearse") | ||||
|         docs = [eg.predicted for eg in examples] | ||||
|         if self._rehearsal_model is None: | ||||
|             return | ||||
|         if not any(len(doc) for doc in docs): | ||||
|  | @ -256,6 +250,7 @@ class Tagger(Pipe): | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/tagger#get_loss | ||||
|         """ | ||||
|         validate_examples(examples, "Tagger.get_loss") | ||||
|         loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False) | ||||
|         truths = [eg.get_aligned("TAG", as_string=True) for eg in examples] | ||||
|         d_scores, loss = loss_func(scores, truths) | ||||
|  | @ -263,7 +258,7 @@ class Tagger(Pipe): | |||
|             raise ValueError("nan value when computing loss") | ||||
|         return float(loss), d_scores | ||||
| 
 | ||||
|     def begin_training(self, get_examples=lambda: [], *, pipeline=None, sgd=None): | ||||
|     def begin_training(self, get_examples, *, pipeline=None, sgd=None): | ||||
|         """Initialize the pipe for training, using data examples if available. | ||||
| 
 | ||||
|         get_examples (Callable[[], Iterable[Example]]): Optional function that | ||||
|  | @ -277,13 +272,12 @@ class Tagger(Pipe): | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/tagger#begin_training | ||||
|         """ | ||||
|         if not hasattr(get_examples, "__call__"): | ||||
|             err = Errors.E930.format(name="Tagger", obj=type(get_examples)) | ||||
|             raise ValueError(err) | ||||
|         tags = set() | ||||
|         for example in get_examples(): | ||||
|             try: | ||||
|                 y = example.y | ||||
|             except AttributeError: | ||||
|                 raise TypeError(Errors.E978.format(name="Tagger", method="begin_training", types=type(example))) from None | ||||
|             for token in y: | ||||
|             for token in example.y: | ||||
|                 tags.add(token.tag_) | ||||
|         for tag in sorted(tags): | ||||
|             self.add_label(tag) | ||||
|  | @ -318,6 +312,7 @@ class Tagger(Pipe): | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/tagger#score | ||||
|         """ | ||||
|         validate_examples(examples, "Tagger.score") | ||||
|         return Scorer.score_token_attr(examples, "tag", **kwargs) | ||||
| 
 | ||||
|     def to_bytes(self, *, exclude=tuple()): | ||||
|  |  | |||
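The tidied `Tagger.begin_training` validates the callback and then collects tag labels in one pass over `example.y`, and NaN scores now raise the catalogued E940 instead of a bare string. A sketch of the label collection, with a stand-in token type:

```python
class Token:
    def __init__(self, tag):
        self.tag_ = tag  # stand-in for spaCy's Token.tag_

def collect_tags(get_examples):
    tags = set()
    for example in get_examples():  # each example is a list of stand-in tokens
        for token in example:
            tags.add(token.tag_)
    return sorted(tags)

corpus = [[Token("NN"), Token("VBZ")], [Token("NN"), Token("DT")]]
print(collect_tags(lambda: corpus))  # ['DT', 'NN', 'VBZ']
```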
|  | @ -5,7 +5,7 @@ import numpy | |||
| 
 | ||||
| from .pipe import Pipe | ||||
| from ..language import Language | ||||
| from ..gold import Example | ||||
| from ..gold import Example, validate_examples | ||||
| from ..errors import Errors | ||||
| from ..scorer import Scorer | ||||
| from .. import util | ||||
|  | @ -209,15 +209,10 @@ class TextCategorizer(Pipe): | |||
|         if losses is None: | ||||
|             losses = {} | ||||
|         losses.setdefault(self.name, 0.0) | ||||
|         try: | ||||
|             if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples): | ||||
|                 # Handle cases where there are no tokens in any docs. | ||||
|                 return losses | ||||
|         except AttributeError: | ||||
|             types = set([type(eg) for eg in examples]) | ||||
|             raise TypeError( | ||||
|                 Errors.E978.format(name="TextCategorizer", method="update", types=types) | ||||
|             ) from None | ||||
|         validate_examples(examples, "TextCategorizer.update") | ||||
|         if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples): | ||||
|             # Handle cases where there are no tokens in any docs. | ||||
|             return losses | ||||
|         set_dropout_rate(self.model, drop) | ||||
|         scores, bp_scores = self.model.begin_update([eg.predicted for eg in examples]) | ||||
|         loss, d_scores = self.get_loss(examples, scores) | ||||
|  | @ -252,19 +247,12 @@ class TextCategorizer(Pipe): | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/textcategorizer#rehearse | ||||
|         """ | ||||
| 
 | ||||
|         if losses is not None: | ||||
|             losses.setdefault(self.name, 0.0) | ||||
|         if self._rehearsal_model is None: | ||||
|             return losses | ||||
|         try: | ||||
|             docs = [eg.predicted for eg in examples] | ||||
|         except AttributeError: | ||||
|             types = set([type(eg) for eg in examples]) | ||||
|             err = Errors.E978.format( | ||||
|                 name="TextCategorizer", method="rehearse", types=types | ||||
|             ) | ||||
|             raise TypeError(err) from None | ||||
|         validate_examples(examples, "TextCategorizer.rehearse") | ||||
|         docs = [eg.predicted for eg in examples] | ||||
|         if not any(len(doc) for doc in docs): | ||||
|             # Handle cases where there are no tokens in any docs. | ||||
|             return losses | ||||
|  | @ -303,6 +291,7 @@ class TextCategorizer(Pipe): | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/textcategorizer#get_loss | ||||
|         """ | ||||
|         validate_examples(examples, "TextCategorizer.get_loss") | ||||
|         truths, not_missing = self._examples_to_truth(examples) | ||||
|         not_missing = self.model.ops.asarray(not_missing) | ||||
|         d_scores = (scores - truths) / scores.shape[0] | ||||
|  | @ -338,7 +327,7 @@ class TextCategorizer(Pipe): | |||
| 
 | ||||
|     def begin_training( | ||||
|         self, | ||||
|         get_examples: Callable[[], Iterable[Example]] = lambda: [], | ||||
|         get_examples: Callable[[], Iterable[Example]], | ||||
|         *, | ||||
|         pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None, | ||||
|         sgd: Optional[Optimizer] = None, | ||||
|  | @ -356,21 +345,20 @@ class TextCategorizer(Pipe): | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/textcategorizer#begin_training | ||||
|         """ | ||||
|         # TODO: begin_training is not guaranteed to see all data / labels ? | ||||
|         examples = list(get_examples()) | ||||
|         for example in examples: | ||||
|             try: | ||||
|                 y = example.y | ||||
|             except AttributeError: | ||||
|                 err = Errors.E978.format( | ||||
|                     name="TextCategorizer", method="update", types=type(example) | ||||
|                 ) | ||||
|                 raise TypeError(err) from None | ||||
|             for cat in y.cats: | ||||
|         if not hasattr(get_examples, "__call__"): | ||||
|             err = Errors.E930.format(name="TextCategorizer", obj=type(get_examples)) | ||||
|             raise ValueError(err) | ||||
|         subbatch = []  # Select a subbatch of examples to initialize the model | ||||
|         for example in get_examples(): | ||||
|             if len(subbatch) < 2: | ||||
|                 subbatch.append(example) | ||||
|             for cat in example.y.cats: | ||||
|                 self.add_label(cat) | ||||
|         self.require_labels() | ||||
|         docs = [Doc(self.vocab, words=["hello"])] | ||||
|         truths, _ = self._examples_to_truth(examples) | ||||
|         docs = [eg.reference for eg in subbatch] | ||||
|         if not docs:  # need at least one doc | ||||
|             docs = [Doc(self.vocab, words=["hello"])] | ||||
|         truths, _ = self._examples_to_truth(subbatch) | ||||
|         self.set_output(len(self.labels)) | ||||
|         self.model.initialize(X=docs, Y=truths) | ||||
|         if sgd is None: | ||||
|  | @ -392,6 +380,7 @@ class TextCategorizer(Pipe): | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/textcategorizer#score | ||||
|         """ | ||||
|         validate_examples(examples, "TextCategorizer.score") | ||||
|         return Scorer.score_cats( | ||||
|             examples, | ||||
|             "cats", | ||||
|  |  | |||
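The rewritten `TextCategorizer.begin_training` no longer materializes the whole corpus via `list(get_examples())` just to shape the model: it streams the examples once, registering every label while keeping only a two-example subbatch whose reference docs and truths are used for `Model.initialize`. A sketch of the selection, with dicts standing in for `Example` objects:

```python
def select_subbatch(get_examples, size=2):
    # Stream once: collect all labels, keep only `size` examples for init.
    subbatch, labels = [], set()
    for example in get_examples():
        if len(subbatch) < size:
            subbatch.append(example)
        labels.update(example["cats"])  # stand-in for example.y.cats
    return subbatch, labels

corpus = [{"cats": {"POS": 1.0}}, {"cats": {"NEG": 1.0}}, {"cats": {"MIXED": 1.0}}]
subbatch, labels = select_subbatch(lambda: iter(corpus))
print(len(subbatch), sorted(labels))  # 2 ['MIXED', 'NEG', 'POS']
```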
|  | @ -2,7 +2,7 @@ from typing import Iterator, Sequence, Iterable, Optional, Dict, Callable, List, | |||
| from thinc.api import Model, set_dropout_rate, Optimizer, Config | ||||
| 
 | ||||
| from .pipe import Pipe | ||||
| from ..gold import Example | ||||
| from ..gold import Example, validate_examples | ||||
| from ..tokens import Doc | ||||
| from ..vocab import Vocab | ||||
| from ..language import Language | ||||
|  | @ -166,9 +166,8 @@ class Tok2Vec(Pipe): | |||
|         """ | ||||
|         if losses is None: | ||||
|             losses = {} | ||||
|         validate_examples(examples, "Tok2Vec.update") | ||||
|         docs = [eg.predicted for eg in examples] | ||||
|         if isinstance(docs, Doc): | ||||
|             docs = [docs] | ||||
|         set_dropout_rate(self.model, drop) | ||||
|         tokvecs, bp_tokvecs = self.model.begin_update(docs) | ||||
|         d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs] | ||||
|  | @ -204,7 +203,7 @@ class Tok2Vec(Pipe): | |||
| 
 | ||||
|     def begin_training( | ||||
|         self, | ||||
|         get_examples: Callable[[], Iterable[Example]] = lambda: [], | ||||
|         get_examples: Callable[[], Iterable[Example]], | ||||
|         *, | ||||
|         pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None, | ||||
|         sgd: Optional[Optimizer] = None, | ||||
|  |  | |||
|  | @ -8,11 +8,8 @@ from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC | |||
| 
 | ||||
| 
 | ||||
| cdef class Parser(Pipe): | ||||
|     cdef readonly Vocab vocab | ||||
|     cdef public object model | ||||
|     cdef public object _rehearsal_model | ||||
|     cdef readonly TransitionSystem moves | ||||
|     cdef readonly object cfg | ||||
|     cdef public object _multitasks | ||||
| 
 | ||||
|     cdef void _parseC(self, StateC** states, | ||||
|  |  | |||
|  | @ -8,22 +8,21 @@ from libc.string cimport memset | |||
| from libc.stdlib cimport calloc, free | ||||
| 
 | ||||
| import srsly | ||||
| from thinc.api import set_dropout_rate | ||||
| import numpy.random | ||||
| import numpy | ||||
| import warnings | ||||
| 
 | ||||
| from ._parser_internals.stateclass cimport StateClass | ||||
| from ..ml.parser_model cimport alloc_activations, free_activations | ||||
| from ..ml.parser_model cimport predict_states, arg_max_if_valid | ||||
| from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss | ||||
| from ..ml.parser_model cimport get_c_weights, get_c_sizes | ||||
| 
 | ||||
| from ..tokens.doc cimport Doc | ||||
| 
 | ||||
| from ..gold import validate_examples | ||||
| from ..errors import Errors, Warnings | ||||
| from .. import util | ||||
| from ..util import create_default_optimizer | ||||
| 
 | ||||
| from thinc.api import set_dropout_rate | ||||
| import numpy.random | ||||
| import numpy | ||||
| import warnings | ||||
| 
 | ||||
| 
 | ||||
| cdef class Parser(Pipe): | ||||
|  | @ -266,6 +265,7 @@ cdef class Parser(Pipe): | |||
|         if losses is None: | ||||
|             losses = {} | ||||
|         losses.setdefault(self.name, 0.) | ||||
|         validate_examples(examples, "Parser.update") | ||||
|         for multitask in self._multitasks: | ||||
|             multitask.update(examples, drop=drop, sgd=sgd) | ||||
|         n_examples = len([eg for eg in examples if self.moves.has_gold(eg)]) | ||||
|  | @ -329,7 +329,7 @@ cdef class Parser(Pipe): | |||
|         if self._rehearsal_model is None: | ||||
|             return None | ||||
|         losses.setdefault(self.name, 0.) | ||||
| 
 | ||||
|         validate_examples(examples, "Parser.rehearse") | ||||
|         docs = [eg.predicted for eg in examples] | ||||
|         states = self.moves.init_batch(docs) | ||||
|         # This is pretty dirty, but the NER can resize itself in init_batch, | ||||
|  | @ -398,21 +398,18 @@ cdef class Parser(Pipe): | |||
|             losses[self.name] += (d_scores**2).sum() | ||||
|         return d_scores | ||||
| 
 | ||||
|     def create_optimizer(self): | ||||
|         return create_default_optimizer() | ||||
| 
 | ||||
|     def set_output(self, nO): | ||||
|         self.model.attrs["resize_output"](self.model, nO) | ||||
| 
 | ||||
|     def begin_training(self, get_examples, pipeline=None, sgd=None, **kwargs): | ||||
|         if not hasattr(get_examples, "__call__"): | ||||
|             err = Errors.E930.format(name="DependencyParser/EntityRecognizer", obj=type(get_examples)) | ||||
|             raise ValueError(err) | ||||
|         self.cfg.update(kwargs) | ||||
|         lexeme_norms = self.vocab.lookups.get_table("lexeme_norm", {}) | ||||
|         if len(lexeme_norms) == 0 and self.vocab.lang in util.LEXEME_NORM_LANGS: | ||||
|             langs = ", ".join(util.LEXEME_NORM_LANGS) | ||||
|             warnings.warn(Warnings.W033.format(model="parser or NER", langs=langs)) | ||||
|         if not hasattr(get_examples, '__call__'): | ||||
|             gold_tuples = get_examples | ||||
|             get_examples = lambda: gold_tuples | ||||
|         actions = self.moves.get_actions( | ||||
|             examples=get_examples(), | ||||
|             min_freq=self.cfg['min_action_freq'], | ||||
|  |  | |||
|  | @ -18,7 +18,7 @@ def test_doc_add_entities_set_ents_iob(en_vocab): | |||
|     cfg = {"model": DEFAULT_NER_MODEL} | ||||
|     model = registry.make_from_config(cfg, validate=True)["model"] | ||||
|     ner = EntityRecognizer(en_vocab, model, **config) | ||||
|     ner.begin_training([]) | ||||
|     ner.begin_training(lambda: []) | ||||
|     ner(doc) | ||||
|     assert len(list(doc.ents)) == 0 | ||||
|     assert [w.ent_iob_ for w in doc] == (["O"] * len(doc)) | ||||
|  | @ -41,7 +41,7 @@ def test_ents_reset(en_vocab): | |||
|     cfg = {"model": DEFAULT_NER_MODEL} | ||||
|     model = registry.make_from_config(cfg, validate=True)["model"] | ||||
|     ner = EntityRecognizer(en_vocab, model, **config) | ||||
|     ner.begin_training([]) | ||||
|     ner.begin_training(lambda: []) | ||||
|     ner(doc) | ||||
|     assert [t.ent_iob_ for t in doc] == (["O"] * len(doc)) | ||||
|     doc.ents = list(doc.ents) | ||||
|  |  | |||
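This and the following test hunks all apply the same migration: `begin_training` no longer defaults its `get_examples` argument and rejects non-callables, so tests that used to pass a bare list (or nothing) now wrap the data in a zero-argument callable. A self-contained sketch of the pattern being enforced:

```python
def begin_training(get_examples, **kwargs):
    # Mirrors the new E930 check: get_examples must be callable.
    if not hasattr(get_examples, "__call__"):
        raise ValueError(f"[E930] invalid get_examples callback, got: {type(get_examples)}")
    return list(get_examples())

begin_training(lambda: [])  # the new calling convention used in the tests
try:
    begin_training([])      # the old convention now fails fast
except ValueError as err:
    print(err)
```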
|  | @ -35,7 +35,7 @@ def test_init_parser(parser): | |||
| def _train_parser(parser): | ||||
|     fix_random_seed(1) | ||||
|     parser.add_label("left") | ||||
|     parser.begin_training([], **parser.cfg) | ||||
|     parser.begin_training(lambda: [], **parser.cfg) | ||||
|     sgd = Adam(0.001) | ||||
| 
 | ||||
|     for i in range(5): | ||||
|  | @ -75,7 +75,7 @@ def test_add_label_deserializes_correctly(): | |||
|     ner1.add_label("C") | ||||
|     ner1.add_label("B") | ||||
|     ner1.add_label("A") | ||||
|     ner1.begin_training([]) | ||||
|     ner1.begin_training(lambda: []) | ||||
|     ner2 = EntityRecognizer(Vocab(), model, **config) | ||||
| 
 | ||||
|     # the second model needs to be resized before we can call from_bytes | ||||
|  |  | |||
|  | @ -28,7 +28,7 @@ def parser(vocab): | |||
|     parser.cfg["hidden_width"] = 32 | ||||
|     # parser.add_label('right') | ||||
|     parser.add_label("left") | ||||
|     parser.begin_training([], **parser.cfg) | ||||
|     parser.begin_training(lambda: [], **parser.cfg) | ||||
|     sgd = Adam(0.001) | ||||
| 
 | ||||
|     for i in range(10): | ||||
|  |  | |||
|  | @ -136,7 +136,7 @@ def test_kb_undefined(nlp): | |||
|     """Test that the EL can't train without defining a KB""" | ||||
|     entity_linker = nlp.add_pipe("entity_linker", config={}) | ||||
|     with pytest.raises(ValueError): | ||||
|         entity_linker.begin_training() | ||||
|         entity_linker.begin_training(lambda: []) | ||||
| 
 | ||||
| 
 | ||||
| def test_kb_empty(nlp): | ||||
|  | @ -145,7 +145,7 @@ def test_kb_empty(nlp): | |||
|     entity_linker = nlp.add_pipe("entity_linker", config=config) | ||||
|     assert len(entity_linker.kb) == 0 | ||||
|     with pytest.raises(ValueError): | ||||
|         entity_linker.begin_training() | ||||
|         entity_linker.begin_training(lambda: []) | ||||
| 
 | ||||
| 
 | ||||
| def test_candidate_generation(nlp): | ||||
|  | @ -249,7 +249,7 @@ def test_preserving_links_asdoc(nlp): | |||
|     ruler.add_patterns(patterns) | ||||
|     el_config = {"kb": {"@assets": "myLocationsKB.v1"}, "incl_prior": False} | ||||
|     el_pipe = nlp.add_pipe("entity_linker", config=el_config, last=True) | ||||
|     el_pipe.begin_training() | ||||
|     el_pipe.begin_training(lambda: []) | ||||
|     el_pipe.incl_context = False | ||||
|     el_pipe.incl_prior = True | ||||
| 
 | ||||
|  |  | |||
|  | @ -54,7 +54,7 @@ def test_textcat_learns_multilabel(): | |||
|     textcat = TextCategorizer(nlp.vocab, width=8) | ||||
|     for letter in letters: | ||||
|         textcat.add_label(letter) | ||||
|     optimizer = textcat.begin_training() | ||||
|     optimizer = textcat.begin_training(lambda: []) | ||||
|     for i in range(30): | ||||
|         losses = {} | ||||
|         examples = [Example.from_dict(doc, {"cats": cats}) for doc, cats in docs] | ||||
|  |  | |||
|  | @ -20,7 +20,7 @@ def test_issue2564(): | |||
|     nlp = Language() | ||||
|     tagger = nlp.add_pipe("tagger") | ||||
|     tagger.add_label("A") | ||||
|     tagger.begin_training() | ||||
|     tagger.begin_training(lambda: []) | ||||
|     doc = nlp("hello world") | ||||
|     assert doc.is_tagged | ||||
|     docs = nlp.pipe(["hello", "world"]) | ||||
|  |  | |||
|  | @ -303,7 +303,7 @@ def test_issue4313(): | |||
|     config = {} | ||||
|     ner = nlp.create_pipe("ner", config=config) | ||||
|     ner.add_label("SOME_LABEL") | ||||
|     ner.begin_training([]) | ||||
|     ner.begin_training(lambda: []) | ||||
|     # add a new label to the doc | ||||
|     doc = nlp("What do you think about Apple ?") | ||||
|     assert len(ner.labels) == 1 | ||||
|  |  | |||
|  | @ -62,7 +62,7 @@ def tagger(): | |||
|     # need to add model for two reasons: | ||||
|     # 1. no model leads to error in serialization, | ||||
|     # 2. the affected line is the one for model serialization | ||||
|     tagger.begin_training(pipeline=nlp.pipeline) | ||||
|     tagger.begin_training(lambda: [], pipeline=nlp.pipeline) | ||||
|     return tagger | ||||
| 
 | ||||
| 
 | ||||
|  | @ -81,7 +81,7 @@ def entity_linker(): | |||
|     # need to add model for two reasons: | ||||
|     # 1. no model leads to error in serialization, | ||||
|     # 2. the affected line is the one for model serialization | ||||
|     entity_linker.begin_training(pipeline=nlp.pipeline) | ||||
|     entity_linker.begin_training(lambda: [], pipeline=nlp.pipeline) | ||||
|     return entity_linker | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -24,6 +24,7 @@ from .util import registry | |||
| from .attrs import intify_attrs | ||||
| from .symbols import ORTH | ||||
| from .scorer import Scorer | ||||
| from .gold import validate_examples | ||||
| 
 | ||||
| 
 | ||||
| cdef class Tokenizer: | ||||
|  | @ -712,6 +713,7 @@ cdef class Tokenizer: | |||
|         return tokens | ||||
| 
 | ||||
|     def score(self, examples, **kwargs): | ||||
|         validate_examples(examples, "Tokenizer.score") | ||||
|         return Scorer.score_tokenization(examples) | ||||
| 
 | ||||
|     def to_disk(self, path, **kwargs): | ||||
|  |  | |||
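
The hunk above only shows the call site; the helper itself lives in `spacy.gold`. As a rough, hypothetical sketch of what such a validation step does (names, checks and error text here are illustrative, not the actual implementation):

```python
from spacy.gold import Example

def validate_examples(examples, method):
    # `method` identifies the caller in the error message, e.g. "Tokenizer.score".
    if not isinstance(examples, (list, tuple)):
        raise TypeError(f"{method} takes a list of Example objects, but got: {type(examples)}")
    wrong_types = {type(eg).__name__ for eg in examples if not isinstance(eg, Example)}
    if wrong_types:
        raise TypeError(f"{method} takes a list of Example objects, but got: {wrong_types}")
```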
|  | @ -45,18 +45,12 @@ Create a new pipeline instance. In your application, you would normally use a | |||
| shortcut for this and instantiate the component using its string name and | ||||
| [`nlp.add_pipe`](/api/language#create_pipe). | ||||
| 
 | ||||
| <Infobox variant="danger"> | ||||
| 
 | ||||
| This method needs to be overwritten with your own custom `__init__` method. | ||||
| 
 | ||||
| </Infobox> | ||||
| 
 | ||||
| | Name    | Type                                       | Description                                                                                 | | ||||
| | ------- | ------------------------------------------ | ------------------------------------------------------------------------------------------- | | ||||
| | `vocab` | `Vocab`                                    | The shared vocabulary.                                                                      | | ||||
| | `model` | [`Model`](https://thinc.ai/docs/api-model) | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component.       | | ||||
| | `name`  | str                                        | String name of the component instance. Used to add entries to the `losses` during training. | | ||||
| | `**cfg` |                                            | Additional config parameters and settings.                                                  | | ||||
| | Name    | Type                                       | Description                                                                                                                     | | ||||
| | ------- | ------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------- | | ||||
| | `vocab` | `Vocab`                                    | The shared vocabulary.                                                                                                          | | ||||
| | `model` | [`Model`](https://thinc.ai/docs/api-model) | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component.                                           | | ||||
| | `name`  | str                                        | String name of the component instance. Used to add entries to the `losses` during training.                                     | | ||||
| | `**cfg` |                                            | Additional config parameters and settings. Will be available as the dictionary `Pipe.cfg` and is serialized with the component. | | ||||
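
For illustration, constructing a component directly and passing an extra setting, assuming the default behavior described in the table; `CustomPipe`, `some_setting` and the throwaway `Linear` model are placeholders, and in practice you would use `nlp.add_pipe` instead:

```python
import spacy
from thinc.api import Linear
from spacy.pipeline import Pipe

class CustomPipe(Pipe):
    def predict(self, docs):
        ...

    def set_annotations(self, docs, scores):
        ...

nlp = spacy.blank("en")
# Any Thinc Model will do for this illustration.
pipe = CustomPipe(nlp.vocab, Linear(2, 2), name="custom_pipe", some_setting=True)
assert pipe.cfg["some_setting"] is True  # extra keyword arguments land in Pipe.cfg
```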
| 
 | ||||
| ## Pipe.\_\_call\_\_ {#call tag="method"} | ||||
| 
 | ||||
|  | @ -182,12 +176,6 @@ method. | |||
| Learn from a batch of [`Example`](/api/example) objects containing the | ||||
| predictions and gold-standard annotations, and update the component's model. | ||||
| 
 | ||||
| <Infobox variant="danger"> | ||||
| 
 | ||||
| This method needs to be overwritten with your own custom `update` method. | ||||
| 
 | ||||
| </Infobox> | ||||
| 
 | ||||
| > #### Example | ||||
| > | ||||
| > ```python | ||||
|  | @ -384,6 +372,15 @@ Load the pipe from a bytestring. Modifies the object in place and returns it. | |||
| | `exclude`      | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. | | ||||
| | **RETURNS**    | `Pipe`          | The pipe.                                                                 | | ||||
| 
 | ||||
| ## Attributes {#attributes} | ||||
| 
 | ||||
| | Name    | Type                                       | Description                                                                                           | | ||||
| | ------- | ------------------------------------------ | ----------------------------------------------------------------------------------------------------- | | ||||
| | `vocab` | [`Vocab`](/api/vocab)                      | The shared vocabulary that's passed in on initialization.                                             | | ||||
| | `model` | [`Model`](https://thinc.ai/docs/api-model) | The model powering the component.                                                                     | | ||||
| | `name`  | str                                        | The name of the component instance in the pipeline. Can be used in the losses.                        | | ||||
| | `cfg`   | dict                                       | Keyword arguments passed to [`Pipe.__init__`](/api/pipe#init). Will be serialized with the component. | | ||||
| 
 | ||||
| ## Serialization fields {#serialization-fields} | ||||
| 
 | ||||
| During serialization, spaCy will export several data fields used to restore | ||||
|  |  | |||
|  | @ -5,7 +5,6 @@ menu: | |||
|   - ['Processing Text', 'processing'] | ||||
|   - ['How Pipelines Work', 'pipelines'] | ||||
|   - ['Custom Components', 'custom-components'] | ||||
|   # - ['Trainable Components', 'trainable-components'] | ||||
|   - ['Extension Attributes', 'custom-components-attributes'] | ||||
|   - ['Plugins & Wrappers', 'plugins'] | ||||
| --- | ||||
|  | @ -885,15 +884,117 @@ available, falls back to looking up the regular factory name. | |||
| 
 | ||||
| </Infobox> | ||||
| 
 | ||||
| <!-- TODO: | ||||
| ## Trainable components {#trainable-components new="3"} | ||||
| ### Trainable components {#trainable-components new="3"} | ||||
| 
 | ||||
| spaCy's [`Pipe`](/api/pipe) class helps you implement your own trainable | ||||
| components that have their own model instance, make predictions over `Doc` | ||||
| objects and can be updated using [`spacy train`](/api/cli#train). This lets you | ||||
| plug fully custom machine learning components into your pipeline. | ||||
| plug fully custom machine learning components into your pipeline. You'll need | ||||
| the following: | ||||
| 
 | ||||
| ---> | ||||
| 1. **Model:** A Thinc [`Model`](https://thinc.ai/docs/api-model) instance. This | ||||
|    can be a model using [layers](https://thinc.ai/docs/api-layers) implemented | ||||
|    in Thinc, or a [wrapped model](https://thinc.ai/docs/usage-frameworks) | ||||
|    implemented in PyTorch, TensorFlow, MXNet or a fully custom solution. The | ||||
|    model must take a list of [`Doc`](/api/doc) objects as input and can have any | ||||
|    type of output. | ||||
| 2. **Pipe subclass:** A subclass of [`Pipe`](/api/pipe) that implements at least | ||||
|    two methods: [`Pipe.predict`](/api/pipe#predict) and | ||||
|    [`Pipe.set_annotations`](/api/pipe#set_annotations). | ||||
| 3. **Component factory:** A component factory registered with | ||||
|    [`@Language.factory`](/api/language#factory) that takes the `nlp` object and | ||||
|    component `name` and optional settings provided by the config and returns an | ||||
|    instance of your trainable component. | ||||
| 
 | ||||
| > #### Example | ||||
| > | ||||
| > ```python | ||||
| > from spacy.pipeline import Pipe | ||||
| > from spacy.language import Language | ||||
| > | ||||
| > class TrainableComponent(Pipe): | ||||
| >     def predict(self, docs): | ||||
| >         ... | ||||
| > | ||||
| >     def set_annotations(self, docs, scores): | ||||
| >         ... | ||||
| > | ||||
| > @Language.factory("my_trainable_component") | ||||
| > def make_component(nlp, name, model): | ||||
| >     return TrainableComponent(nlp.vocab, model, name=name) | ||||
| > ``` | ||||
| 
 | ||||
| | Name                                           | Description                                                                                                         | | ||||
| | ---------------------------------------------- | ------------------------------------------------------------------------------------------------------------------- | | ||||
| | [`predict`](/api/pipe#predict)                 | Apply the component's model to a batch of [`Doc`](/api/doc) objects (without modifying them) and return the scores. | | ||||
| | [`set_annotations`](/api/pipe#set_annotations) | Modify a batch of [`Doc`](/api/doc) objects, using pre-computed scores generated by `predict`.                      | | ||||
| 
 | ||||
| By default, [`Pipe.__init__`](/api/pipe#init) takes the shared vocab, the | ||||
| [`Model`](https://thinc.ai/docs/api-model) and the name of the component | ||||
| instance in the pipeline, which you can use as a key in the losses. All other | ||||
| keyword arguments will become available as [`Pipe.cfg`](/api/pipe#cfg) and will | ||||
| also be serialized with the component. | ||||
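
For example, extending the factory from the example above with an invented `threshold` setting (illustration only, assuming the default `__init__` behavior just described):

```python
from spacy.language import Language
from spacy.pipeline import Pipe

class TrainableComponent(Pipe):
    def predict(self, docs):
        ...

    def set_annotations(self, docs, scores):
        ...

@Language.factory("my_trainable_component", default_config={"threshold": 0.5})
def make_component(nlp, name, model, threshold: float):
    # Passing the setting on as a keyword argument makes it available later
    # as component.cfg["threshold"] and serializes it with the component.
    return TrainableComponent(nlp.vocab, model, name=name, threshold=threshold)
```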
| 
 | ||||
| <Accordion title="Why components should be passed a Model instance, not create it" spaced> | ||||
| 
 | ||||
| spaCy's [config system](/usage/training#config) resolves the config describing | ||||
| the pipeline components and models **bottom-up**. This means that it will | ||||
| _first_ create a `Model` from a [registered architecture](/api/architectures), | ||||
| validate its arguments and _then_ pass the object forward to the component. This | ||||
| means that the config can express very complex, nested trees of objects – but | ||||
| the objects don't have to pass the model settings all the way down to the | ||||
| components. It also makes the components more **modular** and lets you swap | ||||
| different architectures in your config, and re-use model definitions. | ||||
| 
 | ||||
| ```ini | ||||
| ### config.cfg (excerpt) | ||||
| [components] | ||||
| 
 | ||||
| [components.textcat] | ||||
| factory = "textcat" | ||||
| labels = [] | ||||
| 
 | ||||
| # This function is created and then passed to the "textcat" component as | ||||
| # the argument "model" | ||||
| [components.textcat.model] | ||||
| @architectures = "spacy.TextCatEnsemble.v1" | ||||
| exclusive_classes = false | ||||
| pretrained_vectors = null | ||||
| width = 64 | ||||
| conv_depth = 2 | ||||
| embed_size = 2000 | ||||
| window_size = 1 | ||||
| ngram_size = 1 | ||||
| dropout = null | ||||
| 
 | ||||
| [components.other_textcat] | ||||
| factory = "textcat" | ||||
| # This references the [components.textcat.model] block above | ||||
| model = ${components.textcat.model} | ||||
| labels = [] | ||||
| ``` | ||||
| 
 | ||||
| Your trainable pipeline component factories should therefore always take a | ||||
| `model` argument instead of instantiating the | ||||
| [`Model`](https://thinc.ai/docs/api-model) inside the component. To register | ||||
| custom architectures, you can use the | ||||
| [`@spacy.registry.architectures`](/api/top-level#registry) decorator. Also see | ||||
| the [training guide](/usage/training#config) for details. | ||||
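
A sketch of such a registered architecture, using a placeholder name (`MyCustomModel.v1`) and a deliberately trivial Thinc layer; a real component model would need to accept a list of `Doc` objects as input:

```python
import spacy
from thinc.api import Model, Linear

@spacy.registry.architectures("MyCustomModel.v1")
def build_my_model(nO: int, nI: int) -> Model:
    # The registered function just has to return a Thinc Model; its arguments
    # (here nO and nI) become the settings of the [components.*.model] block.
    return Linear(nO, nI)
```

In the config, this model would then be referenced via `@architectures = "MyCustomModel.v1"` in the component's `model` block, just like the built-in `spacy.TextCatEnsemble.v1` in the excerpt above.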
| 
 | ||||
| </Accordion> | ||||
| 
 | ||||
| For some use cases, it also makes sense to overwrite additional methods to | ||||
| customize how the model is updated from examples, how it's initialized and how | ||||
| the loss is calculated, and to add evaluation scores to the training output. | ||||
| 
 | ||||
| | Name                                         | Description                                                                                                                                                                                                                                                                                                        | | ||||
| | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | | ||||
| | [`update`](/api/pipe#update)                 | Learn from a batch of [`Example`](/api/example) objects containing the predictions and gold-standard annotations, and update the component's model.                                                                                                                                                                | | ||||
| | [`begin_training`](/api/pipe#begin_training) | Initialize the model. Typically calls into [`Model.initialize`](https://thinc.ai/docs/api-model#initialize) and [`Pipe.create_optimizer`](/api/pipe#create_optimizer) if no optimizer is provided.                                                                                                                 | | ||||
| | [`get_loss`](/api/pipe#get_loss)             | Return a tuple of the loss and the gradient for a batch of [`Example`](/api/example) objects.                                                                                                                                                                                                                      | | ||||
| | [`score`](/api/pipe#score)                   | Score a batch of [`Example`](/api/example) objects and return a dictionary of scores. The [`@Language.factory`](/api/language#factory) decorator can define the `default_score_weights` of the component to decide which keys of the scores to display during training and how they count towards the final score. | ||||
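
For instance, a rough sketch of `get_loss` and `score` overrides; the squared-error loss, the `_examples_to_truth` helper and the `"my_score"` key are all illustrative, not what the built-in components do:

```python
from spacy.pipeline import Pipe

class TrainableComponent(Pipe):
    # predict and set_annotations as shown earlier ...

    def get_loss(self, examples, scores):
        # Compare the model's scores against the gold-standard annotations and
        # return (loss, gradient). _examples_to_truth is a hypothetical helper
        # that turns the examples into an array aligned with `scores`.
        truths = self._examples_to_truth(examples)
        d_scores = scores - truths
        loss = float((d_scores ** 2).sum())
        return loss, d_scores

    def score(self, examples, **kwargs):
        # Keys returned here are what default_score_weights can refer to.
        return {"my_score": 0.0}
```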
| 
 | ||||
| <!-- TODO: add more details, examples and maybe an example project --> | ||||
| 
 | ||||
| ## Extension attributes {#custom-components-attributes new="2"} | ||||
| 
 | ||||
|  |  | |||