TrainablePipe (#6213)

* rename Pipe to TrainablePipe

* split functionality between Pipe and TrainablePipe

* remove unnecessary methods from certain components

* cleanup

* hasattr(component, "pipe") should be sufficient again

* remove serialization and vocab/cfg from Pipe

* unify _ensure_examples and validate_examples

* small fixes

* hasattr checks for self.cfg and self.vocab

* make is_resizable and is_trainable properties

* serialize strings.json instead of vocab

* fix KB IO + tests

* fix typos

* more typos

* _added_strings as a set

* few more tests specifically for _added_strings field

* bump to 3.0.0a36
Sofie Van Landeghem 2020-10-08 21:33:49 +02:00 committed by GitHub
parent 5ebd1fc2cf
commit d093d6343b
44 changed files with 687 additions and 623 deletions
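
In practical terms, the split means a stateless component now subclasses Pipe and only implements __call__ (the base pipe() falls back to calling it per doc), while model-backed components subclass TrainablePipe and inherit vocab, model, cfg and the new _added_strings bookkeeping. A minimal sketch of the non-trainable side, with an illustrative component name that is not part of this diff:

from spacy.pipeline import Pipe, TrainablePipe

class MarkerComponent(Pipe):
    # the slimmed-down Pipe base no longer carries vocab/model/cfg,
    # so a non-trainable component defines its own state
    def __init__(self, name="marker"):
        self.name = name

    def __call__(self, doc):
        # annotate the Doc in place; Pipe.pipe() now simply calls this per doc
        return doc

assert MarkerComponent().is_trainable is False   # is_trainable is a property now, not a method
assert issubclass(TrainablePipe, Pipe)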

View File

@ -37,6 +37,7 @@ MOD_NAMES = [
"spacy.pipeline.multitask",
"spacy.pipeline.ner",
"spacy.pipeline.pipe",
"spacy.pipeline.trainable_pipe",
"spacy.pipeline.sentencizer",
"spacy.pipeline.senter",
"spacy.pipeline.tagger",

View File

@ -1,6 +1,6 @@
# fmt: off
__title__ = "spacy-nightly"
__version__ = "3.0.0a35"
__version__ = "3.0.0a36"
__download_url__ = "https://github.com/explosion/spacy-models/releases/download"
__compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"
__projects__ = "https://github.com/explosion/projects"

View File

@ -522,14 +522,12 @@ class Errors:
E928 = ("A KnowledgeBase can only be serialized to/from a directory, "
"but the provided argument {loc} points to a file.")
E929 = ("Couldn't read KnowledgeBase from {loc}. The path does not seem to exist.")
E930 = ("Received invalid get_examples callback in `{name}.initialize`. "
E930 = ("Received invalid get_examples callback in `{method}`. "
"Expected function that returns an iterable of Example objects but "
"got: {obj}")
E931 = ("Encountered Pipe subclass without `Pipe.{method}` method in component "
"'{name}'. If the component is trainable and you want to use this "
"method, make sure it's overwritten on the subclass. If your "
"component isn't trainable, add a method that does nothing or "
"don't use the Pipe base class.")
E931 = ("Encountered {parent} subclass without `{parent}.{method}` "
"method in component '{name}'. If you want to use this "
"method, make sure it's overwritten on the subclass.")
E940 = ("Found NaN values in scores.")
E941 = ("Can't find model '{name}'. It looks like you're trying to load a "
"model from a shortcut, which is deprecated as of spaCy v3.0. To "

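The reworked E930 message is raised by the new validate_get_examples helper in spacy.training, which the component hunks further down use in place of the removed Pipe._ensure_examples. Its body is not part of these hunks; inferred from the deleted _ensure_examples code and from the updated tests that now expect TypeError, it presumably behaves roughly like this sketch:

from spacy.errors import Errors
from spacy.training import validate_examples

def validate_get_examples(get_examples, method_name):
    # rough sketch only; the real helper lives in spacy.training
    if get_examples is None or not callable(get_examples):
        raise TypeError(Errors.E930.format(method=method_name, obj=type(get_examples)))
    examples = get_examples()
    if not examples:
        raise TypeError(Errors.E930.format(method=method_name, obj=examples))
    # per the "unify _ensure_examples and validate_examples" commit bullet,
    # the Example type check presumably happens here as well
    validate_examples(examples, method_name)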
View File

@ -30,6 +30,7 @@ cdef class KnowledgeBase:
cdef Pool mem
cpdef readonly Vocab vocab
cdef int64_t entity_vector_length
cdef public set _added_strings
# This maps 64bit keys (hash of unique entity string)
# to 64bit values (position of the _KBEntryC struct in the _entries vector).

View File

@ -1,5 +1,7 @@
# cython: infer_types=True, profile=True
from typing import Iterator
from typing import Iterator, Iterable
import srsly
from cymem.cymem cimport Pool
from preshed.maps cimport PreshMap
from cpython.exc cimport PyErr_SetFromErrno
@ -10,13 +12,10 @@ from libcpp.vector cimport vector
from pathlib import Path
import warnings
from spacy.strings import StringStore
from spacy import util
from .typedefs cimport hash_t
from .errors import Errors, Warnings
from . import util
from .util import SimpleFrozenList, ensure_path
cdef class Candidate:
"""A `Candidate` object refers to a textual mention (`alias`) that may or may not be resolved
@ -85,9 +84,6 @@ cdef class KnowledgeBase:
DOCS: https://nightly.spacy.io/api/kb
"""
contents_loc = "contents"
strings_loc = "strings.json"
def __init__(self, Vocab vocab, entity_vector_length):
"""Create a KnowledgeBase."""
self.mem = Pool()
@ -95,8 +91,8 @@ cdef class KnowledgeBase:
self._entry_index = PreshMap()
self._alias_index = PreshMap()
self.vocab = vocab
self.vocab.strings.add("")
self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
self._added_strings = set()
@property
def entity_vector_length(self):
@ -118,12 +114,16 @@ cdef class KnowledgeBase:
def get_alias_strings(self):
return [self.vocab.strings[x] for x in self._alias_index]
def add_string(self, string: str):
self._added_strings.add(string)
return self.vocab.strings.add(string)
def add_entity(self, unicode entity, float freq, vector[float] entity_vector):
"""
Add an entity to the KB, optionally specifying its log probability based on corpus frequency
Return the hash of the entity ID/name at the end.
"""
cdef hash_t entity_hash = self.vocab.strings.add(entity)
cdef hash_t entity_hash = self.add_string(entity)
# Return if this entity was added before
if entity_hash in self._entry_index:
@ -157,7 +157,7 @@ cdef class KnowledgeBase:
cdef hash_t entity_hash
while i < len(entity_list):
# only process this entity if its unique ID hadn't been added before
entity_hash = self.vocab.strings.add(entity_list[i])
entity_hash = self.add_string(entity_list[i])
if entity_hash in self._entry_index:
warnings.warn(Warnings.W018.format(entity=entity_list[i]))
@ -203,7 +203,7 @@ cdef class KnowledgeBase:
if prob_sum > 1.00001:
raise ValueError(Errors.E133.format(alias=alias, sum=prob_sum))
cdef hash_t alias_hash = self.vocab.strings.add(alias)
cdef hash_t alias_hash = self.add_string(alias)
# Check whether this alias was added before
if alias_hash in self._alias_index:
@ -324,26 +324,27 @@ cdef class KnowledgeBase:
return 0.0
def to_disk(self, path):
path = util.ensure_path(path)
def to_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
path = ensure_path(path)
if not path.exists():
path.mkdir(parents=True)
if not path.is_dir():
raise ValueError(Errors.E928.format(loc=path))
self.write_contents(path / self.contents_loc)
self.vocab.strings.to_disk(path / self.strings_loc)
serialize = {}
serialize["contents"] = lambda p: self.write_contents(p)
serialize["strings.json"] = lambda p: srsly.write_json(p, self._added_strings)
util.to_disk(path, serialize, exclude)
def from_disk(self, path):
path = util.ensure_path(path)
def from_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
path = ensure_path(path)
if not path.exists():
raise ValueError(Errors.E929.format(loc=path))
if not path.is_dir():
raise ValueError(Errors.E928.format(loc=path))
self.read_contents(path / self.contents_loc)
kb_strings = StringStore()
kb_strings.from_disk(path / self.strings_loc)
for string in kb_strings:
self.vocab.strings.add(string)
deserialize = {}
deserialize["contents"] = lambda p: self.read_contents(p)
deserialize["strings.json"] = lambda p: [self.add_string(s) for s in srsly.read_json(p)]
util.from_disk(path, deserialize, exclude)
def write_contents(self, file_path):
cdef Writer writer = Writer(file_path)

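A usage sketch of the reworked KB serialization, based on the hunks above and on the new test_kb_serialize_vocab test: add_string() records custom strings in _added_strings, and to_disk()/from_disk() round-trip them through strings.json instead of dumping the whole vocab. The entity names, values and path below are illustrative:

from spacy.kb import KnowledgeBase
from spacy.vocab import Vocab

kb = KnowledgeBase(Vocab(), entity_vector_length=3)
kb.add_entity("Q42", freq=12, entity_vector=[1.0, 2.0, 3.0])       # calls add_string("Q42")
kb.add_alias("Douglas", entities=["Q42"], probabilities=[1.0])     # calls add_string("Douglas")
assert "Q42" in kb._added_strings

kb.to_disk("/tmp/kb")            # writes "contents" plus "strings.json"
kb2 = KnowledgeBase(Vocab(), entity_vector_length=3)
kb2.from_disk("/tmp/kb")         # re-adds every stored string to the fresh vocab
assert "Q42" in kb2.vocab.strings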
View File

@ -20,7 +20,7 @@ from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
from .training import Example, validate_examples
from .training.initialize import init_vocab, init_tok2vec
from .scorer import Scorer
from .util import registry, SimpleFrozenList
from .util import registry, SimpleFrozenList, _pipe
from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER
from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
@ -1095,7 +1095,7 @@ class Language:
if (
name not in exclude
and hasattr(proc, "is_trainable")
and proc.is_trainable()
and proc.is_trainable
and proc.model not in (True, False, None)
):
proc.finish_update(sgd)
@ -1194,8 +1194,8 @@ class Language:
doc = Doc(self.vocab, words=["x", "y", "z"])
get_examples = lambda: [Example.from_dict(doc, {})]
if not hasattr(get_examples, "__call__"):
err = Errors.E930.format(name="Language", obj=type(get_examples))
raise ValueError(err)
err = Errors.E930.format(method="Language.initialize", obj=type(get_examples))
raise TypeError(err)
# Make sure the config is interpolated so we can resolve subsections
config = self.config.interpolate()
# These are the settings provided in the [initialize] block in the config
@ -1301,16 +1301,7 @@ class Language:
for name, pipe in self.pipeline:
kwargs = component_cfg.get(name, {})
kwargs.setdefault("batch_size", batch_size)
# non-trainable components may have a pipe() implementation that refers to dummy
# predict and set_annotations methods
if (
not hasattr(pipe, "pipe")
or not hasattr(pipe, "is_trainable")
or not pipe.is_trainable()
):
docs = _pipe(docs, pipe, kwargs)
else:
docs = pipe.pipe(docs, **kwargs)
docs = _pipe(docs, pipe, kwargs)
# iterate over the final generator
if len(self.pipeline):
docs = list(docs)
@ -1417,17 +1408,7 @@ class Language:
kwargs = component_cfg.get(name, {})
# Allow component_cfg to overwrite the top-level kwargs.
kwargs.setdefault("batch_size", batch_size)
# non-trainable components may have a pipe() implementation that refers to dummy
# predict and set_annotations methods
if (
hasattr(proc, "pipe")
and hasattr(proc, "is_trainable")
and proc.is_trainable()
):
f = functools.partial(proc.pipe, **kwargs)
else:
# Apply the function, but yield the doc
f = functools.partial(_pipe, proc=proc, kwargs=kwargs)
f = functools.partial(_pipe, proc=proc, kwargs=kwargs)
pipes.append(f)
if n_process != 1:
@ -1826,19 +1807,6 @@ class DisabledPipes(list):
self[:] = []
def _pipe(
examples: Iterable[Example], proc: Callable[[Doc], Doc], kwargs: Dict[str, Any]
) -> Iterator[Example]:
# We added some args for pipe that __call__ doesn't expect.
kwargs = dict(kwargs)
for arg in ["batch_size"]:
if arg in kwargs:
kwargs.pop(arg)
for eg in examples:
eg = proc(eg, **kwargs)
yield eg
def _apply_pipes(
make_doc: Callable[[str], Doc],
pipes: Iterable[Callable[[Doc], Doc]],

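Both call sites above now route every component through a single _pipe helper, which is imported from spacy.util instead of being defined locally (see the changed import at the top of this file). Its new body is not shown in these hunks; given the deleted local version and the hasattr(component, "pipe") commit bullet, it presumably dispatches roughly like this:

def _pipe(docs, proc, kwargs):
    # presumed shape of the relocated helper, not taken from this diff
    if hasattr(proc, "pipe"):
        yield from proc.pipe(docs, **kwargs)
    else:
        # __call__ doesn't accept the batching kwargs that pipe() does
        kwargs = dict(kwargs)
        kwargs.pop("batch_size", None)
        for doc in docs:
            yield proc(doc, **kwargs)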
View File

@ -6,6 +6,7 @@ from .entityruler import EntityRuler
from .lemmatizer import Lemmatizer
from .morphologizer import Morphologizer
from .pipe import Pipe
from .trainable_pipe import TrainablePipe
from .senter import SentenceRecognizer
from .sentencizer import Sentencizer
from .tagger import Tagger
@ -21,6 +22,7 @@ __all__ = [
"EntityRuler",
"Morphologizer",
"Lemmatizer",
"TrainablePipe",
"Pipe",
"SentenceRecognizer",
"Sentencizer",

View File

@ -57,6 +57,7 @@ class AttributeRuler(Pipe):
self.attrs = []
self._attrs_unnormed = [] # store for reference
self.indices = []
self._added_strings = set()
def clear(self) -> None:
"""Reset all patterns."""
@ -123,21 +124,6 @@ class AttributeRuler(Pipe):
set_token_attrs(span[index], attrs)
return doc
def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
"""Apply the pipe to a stream of documents. This usually happens under
the hood when the nlp object is called on a text and all components are
applied to the Doc.
stream (Iterable[Doc]): A stream of documents.
batch_size (int): The number of documents to buffer.
YIELDS (Doc): Processed documents in order.
DOCS: https://spacy.io/attributeruler/pipe#pipe
"""
for doc in stream:
doc = self(doc)
yield doc
def load_from_tag_map(
self, tag_map: Dict[str, Dict[Union[int, str], Union[int, str]]]
) -> None:
@ -201,12 +187,16 @@ class AttributeRuler(Pipe):
# We need to make a string here, because otherwise the ID we pass back
# will be interpreted as the hash of a string, rather than an ordinal.
key = str(len(self.attrs))
self.matcher.add(self.vocab.strings.add(key), patterns)
self.matcher.add(self.add_string(key), patterns)
self._attrs_unnormed.append(attrs)
attrs = normalize_token_attrs(self.vocab, attrs)
self.attrs.append(attrs)
self.indices.append(index)
def add_string(self, string: str):
self._added_strings.add(string)
return self.vocab.strings.add(string)
def add_patterns(self, patterns: Iterable[AttributeRulerPatternType]) -> None:
"""Add patterns from a list of pattern dicts with the keys as the
arguments to AttributeRuler.add.
@ -266,8 +256,8 @@ class AttributeRuler(Pipe):
DOCS: https://nightly.spacy.io/api/attributeruler#to_bytes
"""
serialize = {}
serialize["vocab"] = self.vocab.to_bytes
serialize["patterns"] = lambda: srsly.msgpack_dumps(self.patterns)
serialize["strings.json"] = lambda: srsly.json_dumps(sorted(self._added_strings))
return util.to_bytes(serialize, exclude)
def from_bytes(
@ -286,7 +276,7 @@ class AttributeRuler(Pipe):
self.add_patterns(srsly.msgpack_loads(b))
deserialize = {
"vocab": lambda b: self.vocab.from_bytes(b),
"strings.json": lambda b: [self.add_string(s) for s in srsly.json_loads(b)],
"patterns": load_patterns,
}
util.from_bytes(bytes_data, deserialize, exclude)
@ -303,7 +293,7 @@ class AttributeRuler(Pipe):
DOCS: https://nightly.spacy.io/api/attributeruler#to_disk
"""
serialize = {
"vocab": lambda p: self.vocab.to_disk(p),
"strings.json": lambda p: srsly.write_json(p, self._added_strings),
"patterns": lambda p: srsly.write_msgpack(p, self.patterns),
}
util.to_disk(path, serialize, exclude)
@ -324,7 +314,7 @@ class AttributeRuler(Pipe):
self.add_patterns(srsly.read_msgpack(p))
deserialize = {
"vocab": lambda p: self.vocab.from_disk(p),
"strings.json": lambda p: [self.add_string(s) for s in srsly.read_json(p)],
"patterns": load_patterns,
}
util.from_disk(path, deserialize, exclude)
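
AttributeRuler stays a non-trainable Pipe subclass, so it keeps its own add_string helper, and its serialization now stores only the patterns plus strings.json rather than the full vocab. A small usage sketch; the pattern and attrs are illustrative:

from spacy.lang.en import English

nlp = English()
ruler = nlp.add_pipe("attribute_ruler")
ruler.add(patterns=[[{"ORTH": "lol"}]], attrs={"LEMMA": "laugh out loud"})
# the internal matcher key is registered through add_string, so it ends up in strings.json
ruler_bytes = ruler.to_bytes()   # serializes "patterns" and "strings.json", not "vocab"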

View File

@ -10,10 +10,11 @@ import warnings
from ..kb import KnowledgeBase, Candidate
from ..ml import empty_kb
from ..tokens import Doc
from .pipe import Pipe, deserialize_config
from .pipe import deserialize_config
from .trainable_pipe import TrainablePipe
from ..language import Language
from ..vocab import Vocab
from ..training import Example, validate_examples
from ..training import Example, validate_examples, validate_get_examples
from ..errors import Errors, Warnings
from ..util import SimpleFrozenList
from .. import util
@ -90,7 +91,7 @@ def make_entity_linker(
)
class EntityLinker(Pipe):
class EntityLinker(TrainablePipe):
"""Pipeline component for named entity linking.
DOCS: https://nightly.spacy.io/api/entitylinker
@ -172,7 +173,7 @@ class EntityLinker(Pipe):
DOCS: https://nightly.spacy.io/api/entitylinker#initialize
"""
self._ensure_examples(get_examples)
validate_get_examples(get_examples, "EntityLinker.initialize")
if kb_loader is not None:
self.set_kb(kb_loader)
self.validate_kb()
@ -453,7 +454,6 @@ class EntityLinker(Pipe):
"""
serialize = {}
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
serialize["kb"] = lambda p: self.kb.to_disk(p)
serialize["model"] = lambda p: self.model.to_disk(p)
util.to_disk(path, serialize, exclude)
@ -477,11 +477,12 @@ class EntityLinker(Pipe):
raise ValueError(Errors.E149) from None
deserialize = {}
deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
deserialize["kb"] = lambda p: self.kb.from_disk(p)
deserialize["model"] = load_model
util.from_disk(path, deserialize, exclude)
for s in self.kb._added_strings:
self.vocab.strings.add(s)
return self
def rehearse(self, examples, *, sgd=None, losses=None, **config):

View File

@ -342,12 +342,6 @@ class EntityRuler(Pipe):
validate_examples(examples, "EntityRuler.score")
return Scorer.score_spans(examples, "ents", **kwargs)
def predict(self, docs):
pass
def set_annotations(self, docs, scores):
pass
def from_bytes(
self, patterns_bytes: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
) -> "EntityRuler":

View File

@ -281,7 +281,6 @@ class Lemmatizer(Pipe):
DOCS: https://nightly.spacy.io/api/lemmatizer#to_disk
"""
serialize = {}
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
serialize["lookups"] = lambda p: self.lookups.to_disk(p)
util.to_disk(path, serialize, exclude)
@ -297,7 +296,6 @@ class Lemmatizer(Pipe):
DOCS: https://nightly.spacy.io/api/lemmatizer#from_disk
"""
deserialize = {}
deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
deserialize["lookups"] = lambda p: self.lookups.from_disk(p)
util.from_disk(path, deserialize, exclude)
self._validate_tables()
@ -312,7 +310,6 @@ class Lemmatizer(Pipe):
DOCS: https://nightly.spacy.io/api/lemmatizer#to_bytes
"""
serialize = {}
serialize["vocab"] = self.vocab.to_bytes
serialize["lookups"] = self.lookups.to_bytes
return util.to_bytes(serialize, exclude)
@ -328,7 +325,6 @@ class Lemmatizer(Pipe):
DOCS: https://nightly.spacy.io/api/lemmatizer#from_bytes
"""
deserialize = {}
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
deserialize["lookups"] = lambda b: self.lookups.from_bytes(b)
util.from_bytes(bytes_data, deserialize, exclude)
self._validate_tables()

View File

@ -16,7 +16,7 @@ from .pipe import deserialize_config
from .tagger import Tagger
from .. import util
from ..scorer import Scorer
from ..training import validate_examples
from ..training import validate_examples, validate_get_examples
default_model_config = """
@ -95,6 +95,7 @@ class Morphologizer(Tagger):
# add mappings for empty morph
self.cfg["labels_morph"][Morphology.EMPTY_MORPH] = Morphology.EMPTY_MORPH
self.cfg["labels_pos"][Morphology.EMPTY_MORPH] = POS_IDS[""]
self._added_strings = set()
@property
def labels(self):
@ -128,6 +129,7 @@ class Morphologizer(Tagger):
label_dict.pop(self.POS_FEAT)
# normalize morph string and add to morphology table
norm_morph = self.vocab.strings[self.vocab.morphology.add(label_dict)]
self.add_string(norm_morph)
# add label mappings
if norm_label not in self.cfg["labels_morph"]:
self.cfg["labels_morph"][norm_label] = norm_morph
@ -144,7 +146,7 @@ class Morphologizer(Tagger):
DOCS: https://nightly.spacy.io/api/morphologizer#initialize
"""
self._ensure_examples(get_examples)
validate_get_examples(get_examples, "Morphologizer.initialize")
if labels is not None:
self.cfg["labels_morph"] = labels["morph"]
self.cfg["labels_pos"] = labels["pos"]
@ -159,6 +161,7 @@ class Morphologizer(Tagger):
if pos:
morph_dict[self.POS_FEAT] = pos
norm_label = self.vocab.strings[self.vocab.morphology.add(morph_dict)]
self.add_string(norm_label)
# add label->morph and label->POS mappings
if norm_label not in self.cfg["labels_morph"]:
self.cfg["labels_morph"][norm_label] = morph
@ -176,6 +179,7 @@ class Morphologizer(Tagger):
if pos:
morph_dict[self.POS_FEAT] = pos
norm_label = self.vocab.strings[self.vocab.morphology.add(morph_dict)]
self.add_string(norm_label)
gold_array.append([1.0 if label == norm_label else 0.0 for label in self.labels])
doc_sample.append(example.x)
label_sample.append(self.model.ops.asarray(gold_array, dtype="float32"))
@ -234,6 +238,7 @@ class Morphologizer(Tagger):
if pos:
label_dict[self.POS_FEAT] = pos
label = self.vocab.strings[self.vocab.morphology.add(label_dict)]
self.add_string(label)
eg_truths.append(label)
truths.append(eg_truths)
d_scores, loss = loss_func(scores, truths)

View File

@ -6,7 +6,7 @@ from thinc.api import set_dropout_rate
from ..tokens.doc cimport Doc
from .pipe import Pipe
from .trainable_pipe import TrainablePipe
from .tagger import Tagger
from ..training import validate_examples
from ..language import Language
@ -164,7 +164,7 @@ class MultitaskObjective(Tagger):
return "I-SENT"
class ClozeMultitask(Pipe):
class ClozeMultitask(TrainablePipe):
def __init__(self, vocab, model, **cfg):
self.vocab = vocab
self.model = model

View File

@ -1,5 +1,2 @@
cdef class Pipe:
cdef public object vocab
cdef public object model
cdef public str name
cdef public object cfg

View File

@ -1,38 +1,22 @@
# cython: infer_types=True, profile=True
import warnings
from typing import Optional, Tuple
from typing import Optional, Tuple, Iterable, Iterator, Callable, Union, Dict
import srsly
from thinc.api import set_dropout_rate, Model
from ..tokens.doc cimport Doc
from ..training import validate_examples
from ..training import Example
from ..errors import Errors, Warnings
from .. import util
from ..language import Language
cdef class Pipe:
"""This class is a base class and not instantiated directly. Trainable
pipeline components like the EntityRecognizer or TextCategorizer inherit
from it and it defines the interface that components should follow to
function as trainable components in a spaCy pipeline.
"""This class is a base class and not instantiated directly. It provides
an interface for pipeline components to implement.
Trainable pipeline components like the EntityRecognizer or TextCategorizer
should inherit from the subclass 'TrainablePipe'.
DOCS: https://nightly.spacy.io/api/pipe
"""
def __init__(self, vocab, model, name, **cfg):
"""Initialize a pipeline component.
vocab (Vocab): The shared vocabulary.
model (thinc.api.Model): The Thinc Model powering the pipeline component.
name (str): The component instance name.
**cfg: Additional settings and config parameters.
DOCS: https://nightly.spacy.io/api/pipe#init
"""
self.vocab = vocab
self.model = model
self.name = name
self.cfg = dict(cfg)
@classmethod
def __init_subclass__(cls, **kwargs):
@ -41,18 +25,7 @@ cdef class Pipe:
if hasattr(cls, "begin_training"):
warnings.warn(Warnings.W088.format(name=cls.__name__))
@property
def labels(self) -> Optional[Tuple[str]]:
return []
@property
def label_data(self):
"""Optional JSON-serializable data that would be sufficient to recreate
the label set if provided to the `pipe.initialize()` method.
"""
return None
def __call__(self, Doc doc):
def __call__(self, Doc doc) -> Doc:
"""Apply the pipe to one document. The document is modified in place,
and returned. This usually happens under the hood when the nlp object
is called on a text and all components are applied to the Doc.
@ -62,11 +35,9 @@ cdef class Pipe:
DOCS: https://nightly.spacy.io/api/pipe#call
"""
scores = self.predict([doc])
self.set_annotations([doc], scores)
return doc
raise NotImplementedError(Errors.E931.format(parent="Pipe", method="__call__", name=self.name))
def pipe(self, stream, *, batch_size=128):
def pipe(self, stream: Iterable[Doc], *, batch_size: int=128) -> Iterator[Doc]:
"""Apply the pipe to a stream of documents. This usually happens under
the hood when the nlp object is called on a text and all components are
applied to the Doc.
@ -77,137 +48,17 @@ cdef class Pipe:
DOCS: https://nightly.spacy.io/api/pipe#pipe
"""
for docs in util.minibatch(stream, size=batch_size):
scores = self.predict(docs)
self.set_annotations(docs, scores)
yield from docs
for doc in stream:
doc = self(doc)
yield doc
def predict(self, docs):
"""Apply the pipeline's model to a batch of docs, without modifying them.
Returns a single tensor for a batch of documents.
docs (Iterable[Doc]): The documents to predict.
RETURNS: Vector representations for each token in the documents.
DOCS: https://nightly.spacy.io/api/pipe#predict
"""
raise NotImplementedError(Errors.E931.format(method="predict", name=self.name))
def set_annotations(self, docs, scores):
"""Modify a batch of documents, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify.
scores: The scores to assign.
DOCS: https://nightly.spacy.io/api/pipe#set_annotations
"""
raise NotImplementedError(Errors.E931.format(method="set_annotations", name=self.name))
def update(self, examples, *, drop=0.0, set_annotations=False, sgd=None, losses=None):
"""Learn from a batch of documents and gold-standard information,
updating the pipe's model. Delegates to predict and get_loss.
examples (Iterable[Example]): A batch of Example objects.
drop (float): The dropout rate.
set_annotations (bool): Whether or not to update the Example objects
with the predictions.
sgd (thinc.api.Optimizer): The optimizer.
losses (Dict[str, float]): Optional record of the loss during training.
Updated using the component name as the key.
RETURNS (Dict[str, float]): The updated losses dictionary.
DOCS: https://nightly.spacy.io/api/pipe#update
"""
if losses is None:
losses = {}
if not hasattr(self, "model") or self.model in (None, True, False):
return losses
losses.setdefault(self.name, 0.0)
validate_examples(examples, "Pipe.update")
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
# Handle cases where there are no tokens in any docs.
return
set_dropout_rate(self.model, drop)
scores, bp_scores = self.model.begin_update([eg.predicted for eg in examples])
loss, d_scores = self.get_loss(examples, scores)
bp_scores(d_scores)
if sgd not in (None, False):
self.finish_update(sgd)
losses[self.name] += loss
if set_annotations:
docs = [eg.predicted for eg in examples]
self.set_annotations(docs, scores=scores)
return losses
def rehearse(self, examples, *, sgd=None, losses=None, **config):
"""Perform a "rehearsal" update from a batch of data. Rehearsal updates
teach the current model to make predictions similar to an initial model,
to try to address the "catastrophic forgetting" problem. This feature is
experimental.
examples (Iterable[Example]): A batch of Example objects.
drop (float): The dropout rate.
sgd (thinc.api.Optimizer): The optimizer.
losses (Dict[str, float]): Optional record of the loss during training.
Updated using the component name as the key.
RETURNS (Dict[str, float]): The updated losses dictionary.
DOCS: https://nightly.spacy.io/api/pipe#rehearse
"""
pass
def get_loss(self, examples, scores):
"""Find the loss and gradient of loss for the batch of documents and
their predicted scores.
examples (Iterable[Example]): The batch of examples.
scores: Scores representing the model's predictions.
RETURNS (Tuple[float, float]): The loss and the gradient.
DOCS: https://nightly.spacy.io/api/pipe#get_loss
"""
raise NotImplementedError(Errors.E931.format(method="get_loss", name=self.name))
def add_label(self, label):
"""Add an output label, to be predicted by the model. It's possible to
extend pretrained models with new labels, but care should be taken to
avoid the "catastrophic forgetting" problem.
label (str): The label to add.
RETURNS (int): 0 if label is already present, otherwise 1.
DOCS: https://nightly.spacy.io/api/pipe#add_label
"""
raise NotImplementedError(Errors.E931.format(method="add_label", name=self.name))
def _require_labels(self) -> None:
"""Raise an error if the component's model has no labels defined."""
if not self.labels or list(self.labels) == [""]:
raise ValueError(Errors.E143.format(name=self.name))
def _allow_extra_label(self) -> None:
"""Raise an error if the component can not add any more labels."""
if self.model.has_dim("nO") and self.model.get_dim("nO") == len(self.labels):
if not self.is_resizable():
raise ValueError(Errors.E922.format(name=self.name, nO=self.model.get_dim("nO")))
def create_optimizer(self):
"""Create an optimizer for the pipeline component.
RETURNS (thinc.api.Optimizer): The optimizer.
DOCS: https://nightly.spacy.io/api/pipe#create_optimizer
"""
return util.create_default_optimizer()
def initialize(self, get_examples, *, nlp=None):
"""Initialize the pipe for training, using data examples if available.
This method needs to be implemented by each Pipe component,
ensuring the internal model (if available) is initialized properly
using the provided sample of Example objects.
def initialize(self, get_examples: Callable[[], Iterable[Example]], *, nlp: Language=None):
"""Initialize the pipe. For non-trainable components, this method
is optional. For trainable components, which should inherit
from the subclass TrainablePipe, the provided data examples
should be used to ensure that the internal model is initialized
properly and all input/output dimensions throughout the network are
inferred.
get_examples (Callable[[], Iterable[Example]]): Function that
returns a representative sample of gold-standard Example objects.
@ -217,49 +68,7 @@ cdef class Pipe:
"""
pass
def _ensure_examples(self, get_examples):
if get_examples is None or not hasattr(get_examples, "__call__"):
err = Errors.E930.format(name=self.name, obj=type(get_examples))
raise ValueError(err)
if not get_examples():
err = Errors.E930.format(name=self.name, obj=get_examples())
raise ValueError(err)
def is_resizable(self):
return hasattr(self, "model") and "resize_output" in self.model.attrs
def is_trainable(self):
return hasattr(self, "model") and isinstance(self.model, Model)
def set_output(self, nO):
if self.is_resizable():
self.model.attrs["resize_output"](self.model, nO)
else:
raise NotImplementedError(Errors.E921)
def use_params(self, params):
"""Modify the pipe's model, to use the given parameter values. At the
end of the context, the original parameters are restored.
params (dict): The parameter values to use in the model.
DOCS: https://nightly.spacy.io/api/pipe#use_params
"""
with self.model.use_params(params):
yield
def finish_update(self, sgd):
"""Update parameters using the current parameter gradients.
The Optimizer instance contains the functionality to perform
the stochastic gradient descent.
sgd (thinc.api.Optimizer): The optimizer.
DOCS: https://nightly.spacy.io/api/pipe#finish_update
"""
self.model.finish_update(sgd)
def score(self, examples, **kwargs):
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Union[float, Dict[str, float]]]:
"""Score a batch of examples.
examples (Iterable[Example]): The examples to score.
@ -269,81 +78,25 @@ cdef class Pipe:
"""
return {}
def to_bytes(self, *, exclude=tuple()):
"""Serialize the pipe to a bytestring.
@property
def is_trainable(self) -> bool:
return False
exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (bytes): The serialized object.
@property
def labels(self) -> Optional[Tuple[str]]:
return tuple()
DOCS: https://nightly.spacy.io/api/pipe#to_bytes
@property
def label_data(self):
"""Optional JSON-serializable data that would be sufficient to recreate
the label set if provided to the `pipe.initialize()` method.
"""
serialize = {}
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
serialize["model"] = self.model.to_bytes
if hasattr(self, "vocab"):
serialize["vocab"] = self.vocab.to_bytes
return util.to_bytes(serialize, exclude)
def from_bytes(self, bytes_data, *, exclude=tuple()):
"""Load the pipe from a bytestring.
exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (Pipe): The loaded object.
DOCS: https://nightly.spacy.io/api/pipe#from_bytes
"""
def load_model(b):
try:
self.model.from_bytes(b)
except AttributeError:
raise ValueError(Errors.E149) from None
deserialize = {}
if hasattr(self, "vocab"):
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
deserialize["model"] = load_model
util.from_bytes(bytes_data, deserialize, exclude)
return self
def to_disk(self, path, *, exclude=tuple()):
"""Serialize the pipe to disk.
path (str / Path): Path to a directory.
exclude (Iterable[str]): String names of serialization fields to exclude.
DOCS: https://nightly.spacy.io/api/pipe#to_disk
"""
serialize = {}
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
serialize["model"] = lambda p: self.model.to_disk(p)
util.to_disk(path, serialize, exclude)
def from_disk(self, path, *, exclude=tuple()):
"""Load the pipe from disk.
path (str / Path): Path to a directory.
exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (Pipe): The loaded object.
DOCS: https://nightly.spacy.io/api/pipe#from_disk
"""
def load_model(p):
try:
self.model.from_bytes(p.open("rb").read())
except AttributeError:
raise ValueError(Errors.E149) from None
deserialize = {}
deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
deserialize["model"] = load_model
util.from_disk(path, deserialize, exclude)
return self
return None
def _require_labels(self) -> None:
"""Raise an error if this component has no labels defined."""
if not self.labels or list(self.labels) == [""]:
raise ValueError(Errors.E143.format(name=self.name))
def deserialize_config(path):
if path.exists():

View File

@ -58,9 +58,6 @@ class Sentencizer(Pipe):
else:
self.punct_chars = set(self.default_punct_chars)
def initialize(self, get_examples, nlp=None):
pass
def __call__(self, doc):
"""Apply the sentencizer to a Doc and set Token.is_sent_start.
@ -204,9 +201,3 @@ class Sentencizer(Pipe):
cfg = srsly.read_json(path)
self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
return self
def get_loss(self, examples, scores):
raise NotImplementedError
def add_label(self, label):
raise NotImplementedError

View File

@ -6,12 +6,11 @@ from thinc.api import Model, SequenceCategoricalCrossentropy, Config
from ..tokens.doc cimport Doc
from .pipe import deserialize_config
from .tagger import Tagger
from ..language import Language
from ..errors import Errors
from ..scorer import Scorer
from ..training import validate_examples
from ..training import validate_examples, validate_get_examples
from .. import util
@ -62,6 +61,7 @@ class SentenceRecognizer(Tagger):
self.name = name
self._rehearsal_model = None
self.cfg = {}
self._added_strings = set()
@property
def labels(self):
@ -138,7 +138,7 @@ class SentenceRecognizer(Tagger):
DOCS: https://nightly.spacy.io/api/sentencerecognizer#initialize
"""
self._ensure_examples(get_examples)
validate_get_examples(get_examples, "SentenceRecognizer.initialize")
doc_sample = []
label_sample = []
assert self.labels, Errors.E924.format(name=self.name)

View File

@ -11,13 +11,14 @@ from ..tokens.doc cimport Doc
from ..morphology cimport Morphology
from ..vocab cimport Vocab
from .pipe import Pipe, deserialize_config
from .trainable_pipe import TrainablePipe
from .pipe import deserialize_config
from ..language import Language
from ..attrs import POS, ID
from ..parts_of_speech import X
from ..errors import Errors, Warnings
from ..scorer import Scorer
from ..training import validate_examples
from ..training import validate_examples, validate_get_examples
from .. import util
@ -55,7 +56,7 @@ def make_tagger(nlp: Language, name: str, model: Model):
return Tagger(nlp.vocab, model, name)
class Tagger(Pipe):
class Tagger(TrainablePipe):
"""Pipeline component for part-of-speech tagging.
DOCS: https://nightly.spacy.io/api/tagger
@ -77,6 +78,7 @@ class Tagger(Pipe):
self._rehearsal_model = None
cfg = {"labels": labels or []}
self.cfg = dict(sorted(cfg.items()))
self._added_strings = set()
@property
def labels(self):
@ -274,7 +276,7 @@ class Tagger(Pipe):
DOCS: https://nightly.spacy.io/api/tagger#initialize
"""
self._ensure_examples(get_examples)
validate_get_examples(get_examples, "Tagger.initialize")
if labels is not None:
for tag in labels:
self.add_label(tag)
@ -311,7 +313,7 @@ class Tagger(Pipe):
return 0
self._allow_extra_label()
self.cfg["labels"].append(label)
self.vocab.strings.add(label)
self.add_string(label)
return 1
def score(self, examples, **kwargs):
@ -325,79 +327,3 @@ class Tagger(Pipe):
"""
validate_examples(examples, "Tagger.score")
return Scorer.score_token_attr(examples, "tag", **kwargs)
def to_bytes(self, *, exclude=tuple()):
"""Serialize the pipe to a bytestring.
exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (bytes): The serialized object.
DOCS: https://nightly.spacy.io/api/tagger#to_bytes
"""
serialize = {}
serialize["model"] = self.model.to_bytes
serialize["vocab"] = self.vocab.to_bytes
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
return util.to_bytes(serialize, exclude)
def from_bytes(self, bytes_data, *, exclude=tuple()):
"""Load the pipe from a bytestring.
bytes_data (bytes): The serialized pipe.
exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (Tagger): The loaded Tagger.
DOCS: https://nightly.spacy.io/api/tagger#from_bytes
"""
def load_model(b):
try:
self.model.from_bytes(b)
except AttributeError:
raise ValueError(Errors.E149) from None
deserialize = {
"vocab": lambda b: self.vocab.from_bytes(b),
"cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
"model": lambda b: load_model(b),
}
util.from_bytes(bytes_data, deserialize, exclude)
return self
def to_disk(self, path, *, exclude=tuple()):
"""Serialize the pipe to disk.
path (str / Path): Path to a directory.
exclude (Iterable[str]): String names of serialization fields to exclude.
DOCS: https://nightly.spacy.io/api/tagger#to_disk
"""
serialize = {
"vocab": lambda p: self.vocab.to_disk(p),
"model": lambda p: self.model.to_disk(p),
"cfg": lambda p: srsly.write_json(p, self.cfg),
}
util.to_disk(path, serialize, exclude)
def from_disk(self, path, *, exclude=tuple()):
"""Load the pipe from disk. Modifies the object in place and returns it.
path (str / Path): Path to a directory.
exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (Tagger): The modified Tagger object.
DOCS: https://nightly.spacy.io/api/tagger#from_disk
"""
def load_model(p):
with p.open("rb") as file_:
try:
self.model.from_bytes(file_.read())
except AttributeError:
raise ValueError(Errors.E149) from None
deserialize = {
"vocab": lambda p: self.vocab.from_disk(p),
"cfg": lambda p: self.cfg.update(deserialize_config(p)),
"model": load_model,
}
util.from_disk(path, deserialize, exclude)
return self
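
With the per-component to_bytes/from_bytes/to_disk/from_disk boilerplate above removed, Tagger inherits the shared TrainablePipe serialization, which stores strings.json instead of the full vocab, and add_label() now records each tag via add_string(). A round-trip sketch mirroring the pattern in the updated tests; the training sentence, tags and path are illustrative:

import spacy
from spacy.training import Example

nlp = spacy.blank("en")
tagger = nlp.add_pipe("tagger")
examples = [Example.from_dict(nlp.make_doc("I like cats"), {"tags": ["N", "V", "N"]})]
nlp.initialize(get_examples=lambda: examples)
assert tagger._added_strings == {"N", "V"}     # registered via add_label -> add_string

nlp.to_disk("/tmp/tagger_model")
nlp2 = spacy.load("/tmp/tagger_model")
assert nlp2.get_pipe("tagger")._added_strings == tagger._added_strings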

View File

@ -4,9 +4,9 @@ from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Conf
from thinc.types import Floats2d
import numpy
from .pipe import Pipe
from .trainable_pipe import TrainablePipe
from ..language import Language
from ..training import Example, validate_examples
from ..training import Example, validate_examples, validate_get_examples
from ..errors import Errors
from ..scorer import Scorer
from .. import util
@ -85,7 +85,7 @@ def make_textcat(
return TextCategorizer(nlp.vocab, model, name, threshold=threshold)
class TextCategorizer(Pipe):
class TextCategorizer(TrainablePipe):
"""Pipeline component for text classification.
DOCS: https://nightly.spacy.io/api/textcategorizer
@ -110,6 +110,7 @@ class TextCategorizer(Pipe):
self._rehearsal_model = None
cfg = {"labels": [], "threshold": threshold, "positive_label": None}
self.cfg = dict(cfg)
self._added_strings = set()
@property
def labels(self) -> Tuple[str]:
@ -119,13 +120,6 @@ class TextCategorizer(Pipe):
"""
return tuple(self.cfg["labels"])
@labels.setter
def labels(self, value: List[str]) -> None:
# TODO: This really shouldn't be here. I had a look and I added it when
# I added the labels property, but it's pretty nasty to have this, and
# will lead to problems.
self.cfg["labels"] = tuple(value)
@property
def label_data(self) -> List[str]:
"""RETURNS (List[str]): Information about the component's labels."""
@ -306,7 +300,8 @@ class TextCategorizer(Pipe):
if label in self.labels:
return 0
self._allow_extra_label()
self.labels = tuple(list(self.labels) + [label])
self.cfg["labels"].append(label)
self.add_string(label)
return 1
def initialize(
@ -329,7 +324,7 @@ class TextCategorizer(Pipe):
DOCS: https://nightly.spacy.io/api/textcategorizer#initialize
"""
self._ensure_examples(get_examples)
validate_get_examples(get_examples, "TextCategorizer.initialize")
if labels is None:
for example in get_examples():
for cat in example.y.cats:

View File

@ -2,8 +2,8 @@ from typing import Iterator, Sequence, Iterable, Optional, Dict, Callable, List
from thinc.api import Model, set_dropout_rate, Optimizer, Config
from itertools import islice
from .pipe import Pipe
from ..training import Example, validate_examples
from .trainable_pipe import TrainablePipe
from ..training import Example, validate_examples, validate_get_examples
from ..tokens import Doc
from ..vocab import Vocab
from ..language import Language
@ -32,7 +32,7 @@ def make_tok2vec(nlp: Language, name: str, model: Model) -> "Tok2Vec":
return Tok2Vec(nlp.vocab, model, name)
class Tok2Vec(Pipe):
class Tok2Vec(TrainablePipe):
"""Apply a "token-to-vector" model and set its outputs in the doc.tensor
attribute. This is mostly useful to share a single subnetwork between multiple
components, e.g. to have one embedding and CNN network shared between a
@ -64,6 +64,7 @@ class Tok2Vec(Pipe):
self.name = name
self.listeners = []
self.cfg = {}
self._added_strings = set()
def add_listener(self, listener: "Tok2VecListener") -> None:
"""Add a listener for a downstream component. Usually internals."""
@ -218,7 +219,7 @@ class Tok2Vec(Pipe):
DOCS: https://nightly.spacy.io/api/tok2vec#initialize
"""
self._ensure_examples(get_examples)
validate_get_examples(get_examples, "Tok2Vec.initialize")
doc_sample = []
for example in islice(get_examples(), 10):
doc_sample.append(example.x)

View File

@ -0,0 +1,8 @@
from .pipe cimport Pipe
from ..vocab cimport Vocab
cdef class TrainablePipe(Pipe):
cdef public Vocab vocab
cdef public object model
cdef public object cfg
cdef public set _added_strings

View File

@ -0,0 +1,322 @@
# cython: infer_types=True, profile=True
from typing import Iterable, Iterator, Optional, Dict, Tuple, Callable
import srsly
from thinc.api import set_dropout_rate, Model, Optimizer
from ..tokens.doc cimport Doc
from ..training import validate_examples
from ..errors import Errors
from .pipe import Pipe, deserialize_config
from .. import util
from ..vocab import Vocab
from ..language import Language
from ..training import Example
cdef class TrainablePipe(Pipe):
"""This class is a base class and not instantiated directly. Trainable
pipeline components like the EntityRecognizer or TextCategorizer inherit
from it and it defines the interface that components should follow to
function as trainable components in a spaCy pipeline.
DOCS: https://nightly.spacy.io/api/pipe
"""
def __init__(self, vocab: Vocab, model: Model, name: str, **cfg):
"""Initialize a pipeline component.
vocab (Vocab): The shared vocabulary.
model (thinc.api.Model): The Thinc Model powering the pipeline component.
name (str): The component instance name.
**cfg: Additional settings and config parameters.
DOCS: https://nightly.spacy.io/api/pipe#init
"""
self.vocab = vocab
self.model = model
self.name = name
self.cfg = dict(cfg)
self._added_strings = set()
def __call__(self, Doc doc) -> Doc:
"""Apply the pipe to one document. The document is modified in place,
and returned. This usually happens under the hood when the nlp object
is called on a text and all components are applied to the Doc.
docs (Doc): The Doc to process.
RETURNS (Doc): The processed Doc.
DOCS: https://nightly.spacy.io/api/pipe#call
"""
scores = self.predict([doc])
self.set_annotations([doc], scores)
return doc
def pipe(self, stream: Iterable[Doc], *, batch_size: int=128) -> Iterator[Doc]:
"""Apply the pipe to a stream of documents. This usually happens under
the hood when the nlp object is called on a text and all components are
applied to the Doc.
stream (Iterable[Doc]): A stream of documents.
batch_size (int): The number of documents to buffer.
YIELDS (Doc): Processed documents in order.
DOCS: https://nightly.spacy.io/api/pipe#pipe
"""
for docs in util.minibatch(stream, size=batch_size):
scores = self.predict(docs)
self.set_annotations(docs, scores)
yield from docs
def predict(self, docs: Iterable[Doc]):
"""Apply the pipeline's model to a batch of docs, without modifying them.
Returns a single tensor for a batch of documents.
docs (Iterable[Doc]): The documents to predict.
RETURNS: Vector representations of the predictions.
DOCS: https://nightly.spacy.io/api/pipe#predict
"""
raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="predict", name=self.name))
def set_annotations(self, docs: Iterable[Doc], scores):
"""Modify a batch of documents, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify.
scores: The scores to assign.
DOCS: https://nightly.spacy.io/api/pipe#set_annotations
"""
raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="set_annotations", name=self.name))
def update(self,
examples: Iterable["Example"],
*, drop: float=0.0,
set_annotations: bool=False,
sgd: Optimizer=None,
losses: Optional[Dict[str, float]]=None) -> Dict[str, float]:
"""Learn from a batch of documents and gold-standard information,
updating the pipe's model. Delegates to predict and get_loss.
examples (Iterable[Example]): A batch of Example objects.
drop (float): The dropout rate.
set_annotations (bool): Whether or not to update the Example objects
with the predictions.
sgd (thinc.api.Optimizer): The optimizer.
losses (Dict[str, float]): Optional record of the loss during training.
Updated using the component name as the key.
RETURNS (Dict[str, float]): The updated losses dictionary.
DOCS: https://nightly.spacy.io/api/pipe#update
"""
if losses is None:
losses = {}
if not hasattr(self, "model") or self.model in (None, True, False):
return losses
losses.setdefault(self.name, 0.0)
validate_examples(examples, "TrainablePipe.update")
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
# Handle cases where there are no tokens in any docs.
return
set_dropout_rate(self.model, drop)
scores, bp_scores = self.model.begin_update([eg.predicted for eg in examples])
loss, d_scores = self.get_loss(examples, scores)
bp_scores(d_scores)
if sgd not in (None, False):
self.finish_update(sgd)
losses[self.name] += loss
if set_annotations:
docs = [eg.predicted for eg in examples]
self.set_annotations(docs, scores=scores)
return losses
def rehearse(self,
examples: Iterable[Example],
*,
sgd: Optimizer=None,
losses: Dict[str, float]=None,
**config) -> Dict[str, float]:
"""Perform a "rehearsal" update from a batch of data. Rehearsal updates
teach the current model to make predictions similar to an initial model,
to try to address the "catastrophic forgetting" problem. This feature is
experimental.
examples (Iterable[Example]): A batch of Example objects.
sgd (thinc.api.Optimizer): The optimizer.
losses (Dict[str, float]): Optional record of the loss during training.
Updated using the component name as the key.
RETURNS (Dict[str, float]): The updated losses dictionary.
DOCS: https://nightly.spacy.io/api/pipe#rehearse
"""
pass
def get_loss(self, examples: Iterable[Example], scores) -> Tuple[float, float]:
"""Find the loss and gradient of loss for the batch of documents and
their predicted scores.
examples (Iterable[Example]): The batch of examples.
scores: Scores representing the model's predictions.
RETURNS (Tuple[float, float]): The loss and the gradient.
DOCS: https://nightly.spacy.io/api/pipe#get_loss
"""
raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="get_loss", name=self.name))
def create_optimizer(self) -> Optimizer:
"""Create an optimizer for the pipeline component.
RETURNS (thinc.api.Optimizer): The optimizer.
DOCS: https://nightly.spacy.io/api/pipe#create_optimizer
"""
return util.create_default_optimizer()
def initialize(self, get_examples: Callable[[], Iterable[Example]], *, nlp: Language=None):
"""Initialize the pipe for training, using data examples if available.
This method needs to be implemented by each TrainablePipe component,
ensuring the internal model (if available) is initialized properly
using the provided sample of Example objects.
get_examples (Callable[[], Iterable[Example]]): Function that
returns a representative sample of gold-standard Example objects.
nlp (Language): The current nlp object the component is part of.
DOCS: https://nightly.spacy.io/api/pipe#initialize
"""
raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="initialize", name=self.name))
def add_label(self, label: str) -> int:
"""Add an output label.
For TrainablePipe components, it is possible to
extend pretrained models with new labels, but care should be taken to
avoid the "catastrophic forgetting" problem.
label (str): The label to add.
RETURNS (int): 0 if label is already present, otherwise 1.
DOCS: https://nightly.spacy.io/api/pipe#add_label
"""
raise NotImplementedError(Errors.E931.format(parent="Pipe", method="add_label", name=self.name))
def add_string(self, string: str):
self._added_strings.add(string)
return self.vocab.strings.add(string)
@property
def is_trainable(self) -> bool:
return True
@property
def is_resizable(self) -> bool:
return getattr(self, "model", None) and "resize_output" in self.model.attrs
def _allow_extra_label(self) -> None:
"""Raise an error if the component can not add any more labels."""
if self.model.has_dim("nO") and self.model.get_dim("nO") == len(self.labels):
if not self.is_resizable:
raise ValueError(Errors.E922.format(name=self.name, nO=self.model.get_dim("nO")))
def set_output(self, nO: int) -> None:
if self.is_resizable:
self.model.attrs["resize_output"](self.model, nO)
else:
raise NotImplementedError(Errors.E921)
def use_params(self, params: dict):
"""Modify the pipe's model, to use the given parameter values. At the
end of the context, the original parameters are restored.
params (dict): The parameter values to use in the model.
DOCS: https://nightly.spacy.io/api/pipe#use_params
"""
with self.model.use_params(params):
yield
def finish_update(self, sgd: Optimizer) -> None:
"""Update parameters using the current parameter gradients.
The Optimizer instance contains the functionality to perform
the stochastic gradient descent.
sgd (thinc.api.Optimizer): The optimizer.
DOCS: https://nightly.spacy.io/api/pipe#finish_update
"""
self.model.finish_update(sgd)
def to_bytes(self, *, exclude=tuple()):
"""Serialize the pipe to a bytestring.
exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (bytes): The serialized object.
DOCS: https://nightly.spacy.io/api/pipe#to_bytes
"""
serialize = {}
if hasattr(self, "cfg"):
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
serialize["model"] = self.model.to_bytes
serialize["strings.json"] = lambda: srsly.json_dumps(sorted(self._added_strings))
return util.to_bytes(serialize, exclude)
def from_bytes(self, bytes_data, *, exclude=tuple()):
"""Load the pipe from a bytestring.
exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (TrainablePipe): The loaded object.
DOCS: https://nightly.spacy.io/api/pipe#from_bytes
"""
def load_model(b):
try:
self.model.from_bytes(b)
except AttributeError:
raise ValueError(Errors.E149) from None
deserialize = {}
deserialize["strings.json"] = lambda b: [self.add_string(s) for s in srsly.json_loads(b)]
if hasattr(self, "cfg"):
deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
deserialize["model"] = load_model
util.from_bytes(bytes_data, deserialize, exclude)
return self
def to_disk(self, path, *, exclude=tuple()):
"""Serialize the pipe to disk.
path (str / Path): Path to a directory.
exclude (Iterable[str]): String names of serialization fields to exclude.
DOCS: https://nightly.spacy.io/api/pipe#to_disk
"""
serialize = {}
if hasattr(self, "cfg"):
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
serialize["strings.json"] = lambda p: srsly.write_json(p, self._added_strings)
serialize["model"] = lambda p: self.model.to_disk(p)
util.to_disk(path, serialize, exclude)
def from_disk(self, path, *, exclude=tuple()):
"""Load the pipe from disk.
path (str / Path): Path to a directory.
exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (TrainablePipe): The loaded object.
DOCS: https://nightly.spacy.io/api/pipe#from_disk
"""
def load_model(p):
try:
self.model.from_bytes(p.open("rb").read())
except AttributeError:
raise ValueError(Errors.E149) from None
deserialize = {}
deserialize["strings.json"] = lambda p: [self.add_string(s) for s in srsly.read_json(p)]
if hasattr(self, "cfg"):
deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
deserialize["model"] = load_model
util.from_disk(path, deserialize, exclude)
return self
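
For reference, a bare-bones sketch of what a custom trainable component now provides on top of this base class; the component name and its internals are illustrative, not part of this diff. __init__, pipe(), update() and the strings.json serialization all come from TrainablePipe, so the subclass implements predict, set_annotations, get_loss and initialize, and registers any custom strings through add_string so they survive the round trip:

from spacy.pipeline import TrainablePipe
from spacy.training import validate_get_examples

class RelationScorer(TrainablePipe):
    def predict(self, docs):
        # run the wrapped thinc model without modifying the docs
        return self.model.predict(docs)

    def set_annotations(self, docs, scores):
        for doc, doc_scores in zip(docs, scores):
            doc.user_data[f"{self.name}_scores"] = doc_scores

    def get_loss(self, examples, scores):
        # return (loss, d_scores); the gradient must match the shape of `scores`
        ...

    def initialize(self, get_examples, *, nlp=None):
        validate_get_examples(get_examples, "RelationScorer.initialize")
        doc_sample = [eg.x for eg in get_examples()]
        self.model.initialize(X=doc_sample)

In practice such a component would also be registered with @Language.factory so that nlp.add_pipe can construct it.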

View File

@ -1,13 +1,13 @@
from cymem.cymem cimport Pool
from ..vocab cimport Vocab
from .pipe cimport Pipe
from .trainable_pipe cimport TrainablePipe
from ._parser_internals.transition_system cimport Transition, TransitionSystem
from ._parser_internals._state cimport StateC
from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC
cdef class Parser(Pipe):
cdef class Parser(TrainablePipe):
cdef public object _rehearsal_model
cdef readonly TransitionSystem moves
cdef public object _multitasks

View File

@ -21,13 +21,14 @@ from ..ml.parser_model cimport predict_states, arg_max_if_valid
from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
from ..ml.parser_model cimport get_c_weights, get_c_sizes
from ..tokens.doc cimport Doc
from .trainable_pipe import TrainablePipe
from ..training import validate_examples
from ..training import validate_examples, validate_get_examples
from ..errors import Errors, Warnings
from .. import util
cdef class Parser(Pipe):
cdef class Parser(TrainablePipe):
"""
Base class of the DependencyParser and EntityRecognizer.
"""
@ -75,6 +76,7 @@ cdef class Parser(Pipe):
self.add_multitask_objective(multitask)
self._rehearsal_model = None
self._added_strings = set()
def __getnewargs_ex__(self):
"""This allows pickling the Parser and its keyword-only init arguments"""
@ -118,6 +120,7 @@ cdef class Parser(Pipe):
resized = True
if resized:
self._resize()
self.add_string(label)
return 1
return 0
@ -411,7 +414,7 @@ cdef class Parser(Pipe):
self.model.attrs["resize_output"](self.model, nO)
def initialize(self, get_examples, nlp=None, labels=None):
self._ensure_examples(get_examples)
validate_get_examples(get_examples, "Parser.initialize")
lexeme_norms = self.vocab.lookups.get_table("lexeme_norm", {})
if len(lexeme_norms) == 0 and self.vocab.lang in util.LEXEME_NORM_LANGS:
langs = ", ".join(util.LEXEME_NORM_LANGS)
@ -439,7 +442,7 @@ cdef class Parser(Pipe):
break
# non-trainable components may have a pipe() implementation that refers to dummy
# predict and set_annotations methods
if hasattr(component, "pipe") and hasattr(component, "is_trainable") and component.is_trainable():
if hasattr(component, "pipe"):
doc_sample = list(component.pipe(doc_sample, batch_size=8))
else:
doc_sample = [component(doc) for doc in doc_sample]
@ -454,7 +457,7 @@ cdef class Parser(Pipe):
def to_disk(self, path, exclude=tuple()):
serializers = {
'model': lambda p: (self.model.to_disk(p) if self.model is not True else True),
'vocab': lambda p: self.vocab.to_disk(p),
'strings.json': lambda p: srsly.write_json(p, self._added_strings),
'moves': lambda p: self.moves.to_disk(p, exclude=["strings"]),
'cfg': lambda p: srsly.write_json(p, self.cfg)
}
@ -462,7 +465,7 @@ cdef class Parser(Pipe):
def from_disk(self, path, exclude=tuple()):
deserializers = {
'vocab': lambda p: self.vocab.from_disk(p),
'strings.json': lambda p: [self.add_string(s) for s in srsly.read_json(p)],
'moves': lambda p: self.moves.from_disk(p, exclude=["strings"]),
'cfg': lambda p: self.cfg.update(srsly.read_json(p)),
'model': lambda p: None,
@ -482,7 +485,7 @@ cdef class Parser(Pipe):
def to_bytes(self, exclude=tuple()):
serializers = {
"model": lambda: (self.model.to_bytes()),
"vocab": lambda: self.vocab.to_bytes(),
"strings.json": lambda: srsly.json_dumps(sorted(self._added_strings)),
"moves": lambda: self.moves.to_bytes(exclude=["strings"]),
"cfg": lambda: srsly.json_dumps(self.cfg, indent=2, sort_keys=True)
}
@ -490,7 +493,7 @@ cdef class Parser(Pipe):
def from_bytes(self, bytes_data, exclude=tuple()):
deserializers = {
"vocab": lambda b: self.vocab.from_bytes(b),
"strings.json": lambda b: [self.add_string(s) for s in srsly.json_loads(b)],
"moves": lambda b: self.moves.from_bytes(b, exclude=["strings"]),
"cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
"model": lambda b: None,

View File

@ -368,7 +368,7 @@ class ConfigSchemaInit(BaseModel):
vectors: Optional[StrictStr] = Field(..., title="Path to vectors")
init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights")
tokenizer: Dict[StrictStr, Any] = Field(..., help="Arguments to be passed into Tokenizer.initialize")
components: Dict[StrictStr, Dict[StrictStr, Any]] = Field(..., help="Arguments for Pipe.initialize methods of pipeline components, keyed by component")
components: Dict[StrictStr, Dict[StrictStr, Any]] = Field(..., help="Arguments for TrainablePipe.initialize methods of pipeline components, keyed by component")
# fmt: on
class Config:

View File

@ -133,7 +133,7 @@ def test_kb_custom_length(nlp):
def test_kb_initialize_empty(nlp):
"""Test that the EL can't initialize without examples"""
entity_linker = nlp.add_pipe("entity_linker")
with pytest.raises(ValueError):
with pytest.raises(TypeError):
entity_linker.initialize(lambda: [])
@ -153,6 +153,23 @@ def test_kb_serialize(nlp):
mykb.from_disk(d / "unknown" / "kb")
def test_kb_serialize_vocab(nlp):
"""Test serialization of the KB and custom strings"""
entity = "MyFunnyID"
assert entity not in nlp.vocab.strings
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
assert not mykb.contains_entity(entity)
mykb.add_entity(entity, freq=342, entity_vector=[3])
assert mykb.contains_entity(entity)
assert entity in mykb.vocab.strings
with make_tempdir() as d:
# normal read-write behaviour
mykb.to_disk(d / "kb")
mykb_new = KnowledgeBase(Vocab(), entity_vector_length=1)
mykb_new.from_disk(d / "kb")
assert entity in mykb_new.vocab.strings
def test_candidate_generation(nlp):
"""Test correct candidate generation"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
@ -413,6 +430,7 @@ def test_overfitting_IO():
# Simple test to try and quickly overfit the NEL component - ensuring the ML models work correctly
nlp = English()
vector_length = 3
assert "Q2146908" not in nlp.vocab.strings
# Convert the texts to docs to make sure we have doc.ents set for the training examples
train_examples = []
@ -440,6 +458,9 @@ def test_overfitting_IO():
last=True,
)
entity_linker.set_kb(create_kb)
assert "Q2146908" in entity_linker.vocab.strings
assert "Q2146908" in entity_linker.kb.vocab.strings
assert "Q2146908" in entity_linker.kb._added_strings
# train the NEL pipe
optimizer = nlp.initialize(get_examples=lambda: train_examples)
@ -474,6 +495,10 @@ def test_overfitting_IO():
nlp.to_disk(tmp_dir)
nlp2 = util.load_model_from_path(tmp_dir)
assert nlp2.pipe_names == nlp.pipe_names
assert "Q2146908" in nlp2.vocab.strings
entity_linker2 = nlp2.get_pipe("entity_linker")
assert "Q2146908" in entity_linker2.vocab.strings
assert "Q2146908" in entity_linker2.kb.vocab.strings
predictions = []
for text, annotation in TRAIN_DATA:
doc2 = nlp2(text)

View File

@ -66,9 +66,9 @@ def test_initialize_examples():
# you shouldn't really call this more than once, but for testing it should be fine
nlp.initialize()
nlp.initialize(get_examples=lambda: train_examples)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
nlp.initialize(get_examples=lambda: None)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
nlp.initialize(get_examples=train_examples)
@ -101,3 +101,4 @@ def test_overfitting_IO():
doc2 = nlp2(test_text)
assert [str(t.morph) for t in doc2] == gold_morphs
assert [t.pos_ for t in doc2] == gold_pos_tags
assert nlp.get_pipe("morphologizer")._added_strings == nlp2.get_pipe("morphologizer")._added_strings

View File

@ -1,6 +1,6 @@
import pytest
from spacy.language import Language
from spacy.pipeline import Pipe
from spacy.pipeline import TrainablePipe
from spacy.util import SimpleFrozenList, get_arg_names
@ -376,7 +376,7 @@ def test_pipe_label_data_no_labels(pipe):
def test_warning_pipe_begin_training():
with pytest.warns(UserWarning, match="begin_training"):
class IncompatPipe(Pipe):
class IncompatPipe(TrainablePipe):
def __init__(self):
...

View File

@ -40,9 +40,9 @@ def test_initialize_examples():
# you shouldn't really call this more than once, but for testing it should be fine
nlp.initialize()
nlp.initialize(get_examples=lambda: train_examples)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
nlp.initialize(get_examples=lambda: None)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
nlp.initialize(get_examples=train_examples)
@ -80,3 +80,4 @@ def test_overfitting_IO():
nlp2 = util.load_model_from_path(tmp_dir)
doc2 = nlp2(test_text)
assert [int(t.is_sent_start) for t in doc2] == gold_sent_starts
assert nlp.get_pipe("senter")._added_strings == nlp2.get_pipe("senter")._added_strings

View File

@ -74,13 +74,13 @@ def test_initialize_examples():
# you shouldn't really call this more than once, but for testing it should be fine
nlp.initialize()
nlp.initialize(get_examples=lambda: train_examples)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
nlp.initialize(get_examples=lambda: None)
with pytest.raises(TypeError):
nlp.initialize(get_examples=lambda: train_examples[0])
with pytest.raises(ValueError):
with pytest.raises(TypeError):
nlp.initialize(get_examples=lambda: [])
with pytest.raises(ValueError):
with pytest.raises(TypeError):
nlp.initialize(get_examples=train_examples)
@ -98,6 +98,7 @@ def test_overfitting_IO():
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
assert losses["tagger"] < 0.00001
assert tagger._added_strings == {"J", "N", "V"}
# test the trained model
test_text = "I like blue eggs"
@ -116,6 +117,7 @@ def test_overfitting_IO():
assert doc2[1].tag_ is "V"
assert doc2[2].tag_ is "J"
assert doc2[3].tag_ is "N"
assert nlp2.get_pipe("tagger")._added_strings == {"J", "N", "V"}
def test_tagger_requires_labels():

View File

@ -127,9 +127,9 @@ def test_initialize_examples():
nlp.initialize()
get_examples = make_get_examples(nlp)
nlp.initialize(get_examples=get_examples)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
nlp.initialize(get_examples=lambda: None)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
nlp.initialize(get_examples=get_examples())
@ -146,6 +146,7 @@ def test_overfitting_IO():
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
optimizer = nlp.initialize(get_examples=lambda: train_examples)
assert textcat.model.get_dim("nO") == 2
assert textcat._added_strings == {"NEGATIVE", "POSITIVE"}
for i in range(50):
losses = {}
@ -167,6 +168,7 @@ def test_overfitting_IO():
cats2 = doc2.cats
assert cats2["POSITIVE"] > 0.9
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.001)
assert nlp2.get_pipe("textcat")._added_strings == {"NEGATIVE", "POSITIVE"}
# Test scoring
scores = nlp.evaluate(train_examples)

View File

@ -1,5 +1,5 @@
import pytest
from spacy.pipeline import Pipe
from spacy.pipeline import TrainablePipe
from spacy.matcher import PhraseMatcher, Matcher
from spacy.tokens import Doc, Span, DocBin
from spacy.training import Example, Corpus
@ -271,7 +271,7 @@ def test_issue4272():
def test_multiple_predictions():
class DummyPipe(Pipe):
class DummyPipe(TrainablePipe):
def __init__(self):
self.model = "dummy_model"

View File

@ -1,4 +1,3 @@
from typing import Callable
import warnings
from unittest import TestCase
import pytest
@ -7,8 +6,7 @@ from numpy import zeros
from spacy.kb import KnowledgeBase, Writer
from spacy.vectors import Vectors
from spacy.language import Language
from spacy.pipeline import Pipe
from spacy.util import registry
from spacy.pipeline import TrainablePipe
from ..util import make_tempdir
@ -45,14 +43,13 @@ def custom_pipe():
def from_disk(self, path, exclude=tuple(), **kwargs):
return self
class MyPipe(Pipe):
class MyPipe(TrainablePipe):
def __init__(self, vocab, model=True, **cfg):
if cfg:
self.cfg = cfg
else:
self.cfg = None
self.model = SerializableDummy()
self.vocab = SerializableDummy()
return MyPipe(None)

View File

@ -1,5 +1,6 @@
import pytest
from spacy import registry
import srsly
from spacy import registry, Vocab
from spacy.pipeline import Tagger, DependencyParser, EntityRecognizer
from spacy.pipeline import TextCategorizer, SentenceRecognizer
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
@ -69,6 +70,29 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser):
assert bytes_2 == bytes_3
@pytest.mark.parametrize("Parser", test_parsers)
def test_serialize_parser_strings(Parser):
vocab1 = Vocab()
label = "FunnyLabel"
assert label not in vocab1.strings
config = {
"learn_tokens": False,
"min_action_freq": 0,
"update_with_oracle_cut_size": 100,
}
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.resolve(cfg, validate=True)["model"]
parser1 = Parser(vocab1, model, **config)
parser1.add_label(label)
assert label in parser1.vocab.strings
vocab2 = Vocab()
assert label not in vocab2.strings
parser2 = Parser(vocab2, model, **config)
parser2 = parser2.from_bytes(parser1.to_bytes(exclude=["vocab"]))
assert parser1._added_strings == parser2._added_strings == {"FunnyLabel"}
assert label in parser2.vocab.strings
@pytest.mark.parametrize("Parser", test_parsers)
def test_serialize_parser_roundtrip_disk(en_vocab, Parser):
config = {
@ -132,6 +156,29 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
assert tagger1_d.to_bytes() == tagger2_d.to_bytes()
def test_serialize_tagger_strings(en_vocab, de_vocab, taggers):
label = "SomeWeirdLabel"
assert label not in en_vocab.strings
assert label not in de_vocab.strings
tagger = taggers[0]
assert label not in tagger.vocab.strings
with make_tempdir() as d:
# check that custom labels are serialized as part of the component's strings.json
tagger.add_label(label)
assert label in tagger.vocab.strings
assert tagger._added_strings == {label}
file_path = d / "tagger1"
tagger.to_disk(file_path)
strings = srsly.read_json(file_path / "strings.json")
assert strings == ["SomeWeirdLabel"]
# ensure that the custom strings are loaded back in when using the tagger in another pipeline
cfg = {"model": DEFAULT_TAGGER_MODEL}
model = registry.resolve(cfg, validate=True)["model"]
tagger2 = Tagger(de_vocab, model).from_disk(file_path)
assert label in tagger2.vocab.strings
assert tagger2._added_strings == {label}
def test_serialize_textcat_empty(en_vocab):
# See issue #1105
cfg = {"model": DEFAULT_TEXTCAT_MODEL}

View File

@ -1,5 +1,5 @@
from .corpus import Corpus # noqa: F401
from .example import Example, validate_examples # noqa: F401
from .example import Example, validate_examples, validate_get_examples # noqa: F401
from .align import Alignment # noqa: F401
from .augment import dont_augment, orth_variants_augmenter # noqa: F401
from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401

View File

@ -44,6 +44,24 @@ def validate_examples(examples, method):
raise TypeError(err)
def validate_get_examples(get_examples, method):
    """Check that a generator of a batch of examples received during processing is valid:
    the callable produces a non-empty list of Example objects.
    This function lives here to prevent circular imports.

    get_examples (Callable[[], Iterable[Example]]): A function that produces a batch of examples.
    method (str): The method name to show in error messages.
    """
    if get_examples is None or not hasattr(get_examples, "__call__"):
        err = Errors.E930.format(method=method, obj=type(get_examples))
        raise TypeError(err)
    examples = get_examples()
    if not examples:
        err = Errors.E930.format(method=method, obj=examples)
        raise TypeError(err)
    validate_examples(examples, method)
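To make the intended usage concrete, here is a minimal sketch (not part of this diff) of how a trainable component's `initialize` might call the new helper; the component name string is made up:

```python
from spacy.training import validate_get_examples

def initialize(self, get_examples, *, nlp=None):
    # Raises TypeError if get_examples is not callable, returns nothing,
    # or yields objects that are not Example instances.
    validate_get_examples(get_examples, "my_component.initialize")
    for example in get_examples():
        ...  # inspect the gold annotations to infer labels, shapes, etc.
```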
cdef class Example:
def __init__(self, Doc predicted, Doc reference, *, alignment=None):
if predicted is None:

View File

@ -21,7 +21,7 @@ def console_logger(progress_bar: bool = False):
logged_pipes = [
name
for name, proc in nlp.pipeline
if hasattr(proc, "is_trainable") and proc.is_trainable()
if hasattr(proc, "is_trainable") and proc.is_trainable
]
eval_frequency = nlp.config["training"]["eval_frequency"]
score_weights = nlp.config["training"]["score_weights"]

View File

@ -188,7 +188,7 @@ def train_while_improving(
if (
name not in exclude
and hasattr(proc, "is_trainable")
and proc.is_trainable()
and proc.is_trainable
and proc.model not in (True, False, None)
):
proc.finish_update(optimizer)

View File

@ -1356,3 +1356,16 @@ def check_bool_env_var(env_var: str) -> bool:
if value == "0":
return False
return bool(value)
def _pipe(docs, proc, kwargs):
    if hasattr(proc, "pipe"):
        yield from proc.pipe(docs, **kwargs)
    else:
        # We added some args for pipe that __call__ doesn't expect.
        kwargs = dict(kwargs)
        for arg in ["batch_size"]:
            if arg in kwargs:
                kwargs.pop(arg)
        for doc in docs:
            doc = proc(doc, **kwargs)
            yield doc

View File

@ -1,5 +1,5 @@
---
title: Pipe
title: TrainablePipe
tag: class
teaser: Base class for trainable pipeline components
---
@ -10,30 +10,32 @@ components like the [`EntityRecognizer`](/api/entityrecognizer) or
interface that components should follow to function as trainable components in a
spaCy pipeline. See the docs on
[writing trainable components](/usage/processing-pipelines#trainable-components)
for how to use the `Pipe` base class to implement custom components.
for how to use the `TrainablePipe` base class to implement custom components.
> #### Why is Pipe implemented in Cython?
<!-- TODO: Pipe vs TrainablePipe, check methods below (all renamed to TrainablePipe for now) -->
> #### Why is TrainablePipe implemented in Cython?
>
> The `Pipe` class is implemented in a `.pyx` module, the extension used by
> [Cython](/api/cython). This is needed so that **other** Cython classes, like
> the [`EntityRecognizer`](/api/entityrecognizer) can inherit from it. But it
> doesn't mean you have to implement trainable components in Cython; pure
> Python components like the [`TextCategorizer`](/api/textcategorizer) can also
> inherit from `Pipe`.
> The `TrainablePipe` class is implemented in a `.pyx` module, the extension
> used by [Cython](/api/cython). This is needed so that **other** Cython
> classes, like the [`EntityRecognizer`](/api/entityrecognizer) can inherit from
> it. But it doesn't mean you have to implement trainable components in Cython;
> pure Python components like the [`TextCategorizer`](/api/textcategorizer) can
> also inherit from `TrainablePipe`.
```python
%%GITHUB_SPACY/spacy/pipeline/pipe.pyx
%%GITHUB_SPACY/spacy/pipeline/trainable_pipe.pyx
```
## Pipe.\_\_init\_\_ {#init tag="method"}
## TrainablePipe.\_\_init\_\_ {#init tag="method"}
> #### Example
>
> ```python
> from spacy.pipeline import Pipe
> from spacy.pipeline import TrainablePipe
> from spacy.language import Language
>
> class CustomPipe(Pipe):
> class CustomPipe(TrainablePipe):
> ...
>
> @Language.factory("your_custom_pipe", default_config={"model": MODEL})
@ -45,14 +47,14 @@ Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#create_pipe).
| Name | Description |
| ------- | ------------------------------------------------------------------------------------------------------------------------------- |
| `vocab` | The shared vocabulary. ~~Vocab~~ |
| `model` | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. ~~Model[List[Doc], Any]~~ |
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
| `**cfg` | Additional config parameters and settings. Will be available as the dictionary `Pipe.cfg` and is serialized with the component. |
| Name | Description |
| ------- | -------------------------------------------------------------------------------------------------------------------------- |
| `vocab` | The shared vocabulary. ~~Vocab~~ |
| `model` | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. ~~Model[List[Doc], Any]~~ |
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
| `**cfg` | Additional config parameters and settings. Will be available as the dictionary `cfg` and is serialized with the component. |
## Pipe.\_\_call\_\_ {#call tag="method"}
## TrainablePipe.\_\_call\_\_ {#call tag="method"}
Apply the pipe to one document. The document is modified in place, and returned.
This usually happens under the hood when the `nlp` object is called on a text
@ -75,7 +77,7 @@ and all pipeline components are applied to the `Doc` in order. Both
| `doc` | The document to process. ~~Doc~~ |
| **RETURNS** | The processed document. ~~Doc~~ |
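A minimal usage sketch (not part of this diff), assuming a trained pipeline package such as `en_core_web_sm` is installed:

```python
import spacy

nlp = spacy.load("en_core_web_sm")
tagger = nlp.get_pipe("tagger")
doc = nlp.make_doc("This is a sentence.")
# Calling the component directly mirrors what nlp(text) does internally.
doc = tagger(doc)
```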
## Pipe.pipe {#pipe tag="method"}
## TrainablePipe.pipe {#pipe tag="method"}
Apply the pipe to a stream of documents. This usually happens under the hood
when the `nlp` object is called on a text and all pipeline components are
@ -98,7 +100,7 @@ applied to the `Doc` in order. Both [`__call__`](/api/pipe#call) and
| `batch_size` | The number of documents to buffer. Defaults to `128`. ~~int~~ |
| **YIELDS** | The processed documents in order. ~~Doc~~ |
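A hedged sketch of streaming documents through a component, reusing the `nlp` and `tagger` objects assumed above:

```python
texts = ["First text.", "Second text."]
docs = (nlp.make_doc(text) for text in texts)
for doc in tagger.pipe(docs, batch_size=50):
    pass  # each doc comes back annotated by the tagger
```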
## Pipe.initialize {#initialize tag="method" new="3"}
## TrainablePipe.initialize {#initialize tag="method" new="3"}
Initialize the component for training. `get_examples` should be a function that
returns an iterable of [`Example`](/api/example) objects. The data examples are
@ -128,7 +130,7 @@ This method was previously called `begin_training`.
| _keyword-only_ | |
| `nlp` | The current `nlp` object. Defaults to `None`. ~~Optional[Language]~~ |
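A minimal sketch of the initialization pattern, assuming `train_examples` is a list of [`Example`](/api/example) objects built elsewhere:

```python
nlp = spacy.blank("en")
tagger = nlp.add_pipe("tagger")
tagger.initialize(lambda: train_examples, nlp=nlp)
```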
## Pipe.predict {#predict tag="method"}
## TrainablePipe.predict {#predict tag="method"}
Apply the component's model to a batch of [`Doc`](/api/doc) objects, without
modifying them.
@ -151,7 +153,7 @@ This method needs to be overwritten with your own custom `predict` method.
| `docs` | The documents to predict. ~~Iterable[Doc]~~ |
| **RETURNS** | The model's prediction for each document. |
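A sketch of calling `predict` directly; the exact type of the returned scores depends on the component's model:

```python
docs = [nlp.make_doc("A sentence."), nlp.make_doc("Another sentence.")]
scores = tagger.predict(docs)
```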
## Pipe.set_annotations {#set_annotations tag="method"}
## TrainablePipe.set_annotations {#set_annotations tag="method"}
Modify a batch of [`Doc`](/api/doc) objects, using pre-computed scores.
@ -175,7 +177,7 @@ method.
| `docs` | The documents to modify. ~~Iterable[Doc]~~ |
| `scores` | The scores to set, produced by `Tagger.predict`. |
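A sketch of the usual pairing of `predict` and `set_annotations`, continuing from the snippet above:

```python
scores = tagger.predict(docs)
tagger.set_annotations(docs, scores)  # writes predictions back onto the docs
```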
## Pipe.update {#update tag="method"}
## TrainablePipe.update {#update tag="method"}
Learn from a batch of [`Example`](/api/example) objects containing the
predictions and gold-standard annotations, and update the component's model.
@ -198,7 +200,7 @@ predictions and gold-standard annotations, and update the component's model.
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
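A minimal training-step sketch, assuming `train_examples` as above:

```python
optimizer = nlp.initialize(get_examples=lambda: train_examples)
losses = {}
losses = tagger.update(train_examples, sgd=optimizer, losses=losses)
```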
## Pipe.rehearse {#rehearse tag="method,experimental" new="3"}
## TrainablePipe.rehearse {#rehearse tag="method,experimental" new="3"}
Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
current model to make predictions similar to an initial model, to try to address
@ -216,12 +218,11 @@ the "catastrophic forgetting" problem. This feature is experimental.
| -------------- | ------------------------------------------------------------------------------------------------------------------------ |
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
| _keyword-only_ | |
| `drop` | The dropout rate. ~~float~~ |
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
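A hedged sketch of a rehearsal step; `nlp.resume_training` is used here to set up the initial model that the component is rehearsed against:

```python
optimizer = nlp.resume_training()
losses = tagger.rehearse(train_examples, sgd=optimizer, losses={})
```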
## Pipe.get_loss {#get_loss tag="method"}
## TrainablePipe.get_loss {#get_loss tag="method"}
Find the loss and gradient of loss for the batch of documents and their
predicted scores.
@ -246,7 +247,7 @@ This method needs to be overwritten with your own custom `get_loss` method.
| `scores` | Scores representing the model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
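A sketch, under the assumption that the scores were produced by the component's own model for the same batch of examples:

```python
scores = tagger.model.begin_update([eg.predicted for eg in train_examples])[0]
loss, d_scores = tagger.get_loss(train_examples, scores)
```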
## Pipe.score {#score tag="method" new="3"}
## TrainablePipe.score {#score tag="method" new="3"}
Score a batch of examples.
@ -261,7 +262,7 @@ Score a batch of examples.
| `examples` | The examples to score. ~~Iterable[Example]~~ |
| **RETURNS** | The scores, e.g. produced by the [`Scorer`](/api/scorer). ~~Dict[str, Union[float, Dict[str, float]]]~~ |
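A sketch; the keys of the returned dictionary are component-specific, e.g. `tag_acc` for the tagger:

```python
scores = tagger.score(train_examples)
print(scores.get("tag_acc"))
```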
## Pipe.create_optimizer {#create_optimizer tag="method"}
## TrainablePipe.create_optimizer {#create_optimizer tag="method"}
Create an optimizer for the pipeline component. Defaults to
[`Adam`](https://thinc.ai/docs/api-optimizers#adam) with default settings.
@ -277,7 +278,7 @@ Create an optimizer for the pipeline component. Defaults to
| ----------- | ---------------------------- |
| **RETURNS** | The optimizer. ~~Optimizer~~ |
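A one-line sketch:

```python
optimizer = tagger.create_optimizer()
```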
## Pipe.use_params {#use_params tag="method, contextmanager"}
## TrainablePipe.use_params {#use_params tag="method, contextmanager"}
Modify the pipe's model, to use the given parameter values. At the end of the
context, the original parameters are restored.
@ -294,7 +295,7 @@ context, the original parameters are restored.
| -------- | -------------------------------------------------- |
| `params` | The parameter values to use in the model. ~~dict~~ |
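A sketch of the common pattern of saving the component with averaged parameters, assuming a Thinc optimizer with parameter averaging enabled:

```python
with tagger.use_params(optimizer.averages):
    tagger.to_disk("/path/to/tagger")
```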
## Pipe.finish_update {#finish_update tag="method"}
## TrainablePipe.finish_update {#finish_update tag="method"}
Update parameters using the current parameter gradients. Defaults to calling
[`self.model.finish_update`](https://thinc.ai/docs/api-model#finish_update).
@ -312,7 +313,7 @@ Update parameters using the current parameter gradients. Defaults to calling
| ----- | ------------------------------------- |
| `sgd` | An optimizer. ~~Optional[Optimizer]~~ |
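A sketch of deferring the gradient application: call `update` with `sgd=None` so gradients accumulate, then apply them explicitly:

```python
optimizer = nlp.initialize(get_examples=lambda: train_examples)
tagger.update(train_examples, sgd=None)
tagger.finish_update(optimizer)
```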
## Pipe.add_label {#add_label tag="method"}
## TrainablePipe.add_label {#add_label tag="method"}
> #### Example
>
@ -347,12 +348,12 @@ case, all labels found in the sample will be automatically added to the model,
and the output dimension will be
[inferred](/usage/layers-architectures#thinc-shape-inference) automatically.
## Pipe.is_resizable {#is_resizable tag="method"}
## TrainablePipe.is_resizable {#is_resizable tag="property"}
> #### Example
>
> ```python
> can_resize = pipe.is_resizable()
> can_resize = pipe.is_resizable
> ```
>
> With custom resizing implemented by a component:
@ -378,7 +379,7 @@ as an attribute to the component's model.
| ----------- | ---------------------------------------------------------------------------------------------- |
| **RETURNS** | Whether or not the output dimension of the model can be changed after initialization. ~~bool~~ |
## Pipe.set_output {#set_output tag="method"}
## TrainablePipe.set_output {#set_output tag="method"}
Change the output dimension of the component's model. If the component is not
[resizable](#is_resizable), this method will raise a `NotImplementedError`. If a
@ -390,7 +391,7 @@ care should be taken to avoid the "catastrophic forgetting" problem.
> #### Example
>
> ```python
> if pipe.is_resizable():
> if pipe.is_resizable:
> pipe.set_output(512)
> ```
@ -398,7 +399,7 @@ care should be taken to avoid the "catastrophic forgetting" problem.
| ---- | --------------------------------- |
| `nO` | The new output dimension. ~~int~~ |
## Pipe.to_disk {#to_disk tag="method"}
## TrainablePipe.to_disk {#to_disk tag="method"}
Serialize the pipe to disk.
@ -415,7 +416,7 @@ Serialize the pipe to disk.
| _keyword-only_ | |
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
## Pipe.from_disk {#from_disk tag="method"}
## TrainablePipe.from_disk {#from_disk tag="method"}
Load the pipe from disk. Modifies the object in place and returns it.
@ -431,9 +432,9 @@ Load the pipe from disk. Modifies the object in place and returns it.
| `path` | A path to a directory. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ |
| _keyword-only_ | |
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
| **RETURNS** | The modified pipe. ~~Pipe~~ |
| **RETURNS** | The modified pipe. ~~TrainablePipe~~ |
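A sketch of a disk round-trip for a component; the path is hypothetical:

```python
tagger.to_disk("/path/to/tagger")
tagger = tagger.from_disk("/path/to/tagger")
```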
## Pipe.to_bytes {#to_bytes tag="method"}
## TrainablePipe.to_bytes {#to_bytes tag="method"}
> #### Example
>
@ -450,7 +451,7 @@ Serialize the pipe to a bytestring.
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
| **RETURNS** | The serialized form of the pipe. ~~bytes~~ |
## Pipe.from_bytes {#from_bytes tag="method"}
## TrainablePipe.from_bytes {#from_bytes tag="method"}
Load the pipe from a bytestring. Modifies the object in place and returns it.
@ -467,16 +468,16 @@ Load the pipe from a bytestring. Modifies the object in place and returns it.
| `bytes_data` | The data to load from. ~~bytes~~ |
| _keyword-only_ | |
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
| **RETURNS** | The pipe. ~~Pipe~~ |
| **RETURNS** | The pipe. ~~TrainablePipe~~ |
## Attributes {#attributes}
| Name | Description |
| ------- | ------------------------------------------------------------------------------------------------------------------------ |
| `vocab` | The shared vocabulary that's passed in on initialization. ~~Vocab~~ |
| `model` | The model powering the component. ~~Model[List[Doc], Any]~~ |
| `name` | The name of the component instance in the pipeline. Can be used in the losses. ~~str~~ |
| `cfg` | Keyword arguments passed to [`Pipe.__init__`](/api/pipe#init). Will be serialized with the component. ~~Dict[str, Any]~~ |
| Name | Description |
| ------- | --------------------------------------------------------------------------------------------------------------------------------- |
| `vocab` | The shared vocabulary that's passed in on initialization. ~~Vocab~~ |
| `model` | The model powering the component. ~~Model[List[Doc], Any]~~ |
| `name` | The name of the component instance in the pipeline. Can be used in the losses. ~~str~~ |
| `cfg` | Keyword arguments passed to [`TrainablePipe.__init__`](/api/pipe#init). Will be serialized with the component. ~~Dict[str, Any]~~ |
## Serialization fields {#serialization-fields}
@ -487,11 +488,10 @@ serialization by passing in the string names via the `exclude` argument.
> #### Example
>
> ```python
> data = pipe.to_disk("/path", exclude=["vocab"])
> data = pipe.to_disk("/path")
> ```
| Name | Description |
| ------- | -------------------------------------------------------------- |
| `vocab` | The shared [`Vocab`](/api/vocab). |
| `cfg` | The config file. You usually don't want to exclude this. |
| `model` | The binary model data. You usually don't want to exclude this. |

View File

@ -57,7 +57,8 @@ components for different language processing tasks and also allows adding
| [`Sentencizer`](/api/sentencizer) | Implement rule-based sentence boundary detection that doesn't require the dependency parse. |
| [`SentenceRecognizer`](/api/sentencerecognizer) | Predict sentence boundaries. |
| [Other functions](/api/pipeline-functions) | Automatically apply something to the `Doc`, e.g. to merge spans of tokens. |
| [`Pipe`](/api/pipe) | Base class that all trainable pipeline components inherit from. |
| [`Pipe`](/api/pipe) | Base class that pipeline components may inherit from. |
| [`TrainablePipe`](/api/pipe) | Class that all trainable pipeline components inherit from. |
### Matchers {#architecture-matchers}

View File

@ -491,13 +491,14 @@ In addition to [swapping out](#swap-architectures) default models in built-in
components, you can also implement an entirely new,
[trainable](/usage/processing-pipelines#trainable-components) pipeline component
from scratch. This can be done by creating a new class inheriting from
[`Pipe`](/api/pipe), and linking it up to your custom model implementation.
[`TrainablePipe`](/api/pipe), and linking it up to your custom model
implementation.
<Infobox title="Trainable component API" emoji="💡">
For details on how to implement pipeline components, check out the usage guide
on [custom components](/usage/processing-pipelines#custom-component) and the
overview of the `Pipe` methods used by
overview of the `TrainablePipe` methods used by
[trainable components](/usage/processing-pipelines#trainable-components).
</Infobox>
@ -646,15 +647,15 @@ get_candidates = model.attrs["get_candidates"]
To use our new relation extraction model as part of a custom
[trainable component](/usage/processing-pipelines#trainable-components), we
create a subclass of [`Pipe`](/api/pipe) that holds the model.
create a subclass of [`TrainablePipe`](/api/pipe) that holds the model.
![Illustration of Pipe methods](../images/trainable_component.svg)
```python
### Pipeline component skeleton
from spacy.pipeline import Pipe
from spacy.pipeline import TrainablePipe
class RelationExtractor(Pipe):
class RelationExtractor(TrainablePipe):
def __init__(self, vocab, model, name="rel"):
"""Create a component instance."""
self.model = model
@ -757,9 +758,10 @@ def update(
When the internal model is trained, the component can be used to make novel
**predictions**. The [`predict`](/api/pipe#predict) function needs to be
implemented for each subclass of `Pipe`. In our case, we can simply delegate to
the internal model's [predict](https://thinc.ai/docs/api-model#predict) function
that takes a batch of `Doc` objects and returns a ~~Floats2d~~ array:
implemented for each subclass of `TrainablePipe`. In our case, we can simply
delegate to the internal model's
[predict](https://thinc.ai/docs/api-model#predict) function that takes a batch
of `Doc` objects and returns a ~~Floats2d~~ array:
```python
### The predict method
@ -826,7 +828,7 @@ def __call__(self, Doc doc):
return doc
```
Once our `Pipe` subclass is fully implemented, we can
Once our `TrainablePipe` subclass is fully implemented, we can
[register](/usage/processing-pipelines#custom-components-factories) the
component with the [`@Language.factory`](/api/language#factory) decorator. This
assigns it a name and lets you create the component with

View File

@ -1169,10 +1169,10 @@ doc = nlp("This is a text...")
## Trainable components {#trainable-components new="3"}
spaCy's [`Pipe`](/api/pipe) class helps you implement your own trainable
components that have their own model instance, make predictions over `Doc`
objects and can be updated using [`spacy train`](/api/cli#train). This lets you
plug fully custom machine learning components into your pipeline.
spaCy's [`TrainablePipe`](/api/pipe) class helps you implement your own
trainable components that have their own model instance, make predictions over
`Doc` objects and can be updated using [`spacy train`](/api/cli#train). This
lets you plug fully custom machine learning components into your pipeline.
![Illustration of Pipe methods](../images/trainable_component.svg)
@ -1183,9 +1183,9 @@ You'll need the following:
a [wrapped model](/usage/layers-architectures#frameworks) implemented in
PyTorch, TensorFlow, MXNet or a fully custom solution. The model must take a
list of [`Doc`](/api/doc) objects as input and can have any type of output.
2. **Pipe subclass:** A subclass of [`Pipe`](/api/pipe) that implements at least
two methods: [`Pipe.predict`](/api/pipe#predict) and
[`Pipe.set_annotations`](/api/pipe#set_annotations).
2. **TrainablePipe subclass:** A subclass of [`TrainablePipe`](/api/pipe) that
implements at least two methods: [`TrainablePipe.predict`](/api/pipe#predict)
and [`TrainablePipe.set_annotations`](/api/pipe#set_annotations).
3. **Component factory:** A component factory registered with
[`@Language.factory`](/api/language#factory) that takes the `nlp` object and
component `name` and optional settings provided by the config and returns an
@ -1194,10 +1194,10 @@ You'll need the following:
> #### Example
>
> ```python
> from spacy.pipeline import Pipe
> from spacy.pipeline import TrainablePipe
> from spacy.language import Language
>
> class TrainableComponent(Pipe):
> class TrainableComponent(TrainablePipe):
> def predict(self, docs):
> ...
>
@ -1214,11 +1214,11 @@ You'll need the following:
| [`predict`](/api/pipe#predict) | Apply the component's model to a batch of [`Doc`](/api/doc) objects (without modifying them) and return the scores. |
| [`set_annotations`](/api/pipe#set_annotations) | Modify a batch of [`Doc`](/api/doc) objects, using pre-computed scores generated by `predict`. |
By default, [`Pipe.__init__`](/api/pipe#init) takes the shared vocab, the
[`Model`](https://thinc.ai/docs/api-model) and the name of the component
By default, [`TrainablePipe.__init__`](/api/pipe#init) takes the shared vocab,
the [`Model`](https://thinc.ai/docs/api-model) and the name of the component
instance in the pipeline, which you can use as a key in the losses. All other
keyword arguments will become available as [`Pipe.cfg`](/api/pipe#cfg) and will
also be serialized with the component.
keyword arguments will become available as [`TrainablePipe.cfg`](/api/pipe#cfg)
and will also be serialized with the component.
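As a minimal sketch (the `threshold` setting is made up, not part of spaCy), extra keyword arguments can simply be stored on `cfg` in `__init__`:

```python
### Storing custom settings on cfg
class TrainableComponent(TrainablePipe):
    def __init__(self, vocab, model, name="trainable_component", *, threshold=0.5):
        self.vocab = vocab
        self.model = model
        self.name = name
        # cfg is serialized with the component, so the setting survives to_disk/from_disk
        self.cfg = {"threshold": threshold}
```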
<Accordion title="Why components should be passed a Model instance, not create it" spaced>

View File

@ -178,7 +178,8 @@ freely combine implementations from different frameworks into a single model.
- **Thinc: **
[Wrapping PyTorch, TensorFlow & MXNet](https://thinc.ai/docs/usage-frameworks),
[`Model` API](https://thinc.ai/docs/api-model)
- **API:** [Model architectures](/api/architectures), [`Pipe`](/api/pipe)
- **API:** [Model architectures](/api/architectures),
[`TrainablePipe`](/api/pipe)
</Infobox>
@ -428,7 +429,7 @@ The following methods, attributes and commands are new in spaCy v3.0.
| [`Language.config`](/api/language#config) | The [config](/usage/training#config) used to create the current `nlp` object. An instance of [`Config`](https://thinc.ai/docs/api-config#config) and can be saved to disk and used for training. |
| [`Language.components`](/api/language#attributes), [`Language.component_names`](/api/language#attributes) | All available components and component names, including disabled components that are not run as part of the pipeline. |
| [`Language.disabled`](/api/language#attributes) | Names of disabled components that are not run as part of the pipeline. |
| [`Pipe.score`](/api/pipe#score) | Method on pipeline components that returns a dictionary of evaluation scores. |
| [`TrainablePipe.score`](/api/pipe#score) | Method on pipeline components that returns a dictionary of evaluation scores. |
| [`registry`](/api/top-level#registry) | Function registry to map functions to string names that can be referenced in [configs](/usage/training#config). |
| [`util.load_meta`](/api/top-level#util.load_meta), [`util.load_config`](/api/top-level#util.load_config) | Updated helpers for loading a pipeline's [`meta.json`](/api/data-formats#meta) and [`config.cfg`](/api/data-formats#config). |
| [`util.get_installed_models`](/api/top-level#util.get_installed_models) | Names of all pipeline packages installed in the environment. |
@ -483,7 +484,7 @@ format for documenting argument and return types.
[`Morphologizer`](/api/morphologizer),
[`AttributeRuler`](/api/attributeruler),
[`SentenceRecognizer`](/api/sentencerecognizer),
[`DependencyMatcher`](/api/dependencymatcher), [`Pipe`](/api/pipe),
[`DependencyMatcher`](/api/dependencymatcher), [`TrainablePipe`](/api/pipe),
[`Corpus`](/api/corpus)
</Infobox>
@ -522,7 +523,7 @@ Note that spaCy v3.0 now requires **Python 3.6+**.
[`@Language.factory`](/api/language#factory) decorator.
- The [`Language.update`](/api/language#update),
[`Language.evaluate`](/api/language#evaluate) and
[`Pipe.update`](/api/pipe#update) methods now all take batches of
[`TrainablePipe.update`](/api/pipe#update) methods now all take batches of
[`Example`](/api/example) objects instead of `Doc` and `GoldParse` objects, or
raw text and a dictionary of annotations.
- The `begin_training` methods have been renamed to `initialize` and now take a
@ -947,7 +948,7 @@ annotations = {"entities": [(0, 15, "PERSON"), (30, 38, "ORG")]}
The [`Language.update`](/api/language#update),
[`Language.evaluate`](/api/language#evaluate) and
[`Pipe.update`](/api/pipe#update) methods now all take batches of
[`TrainablePipe.update`](/api/pipe#update) methods now all take batches of
[`Example`](/api/example) objects instead of `Doc` and `GoldParse` objects, or
raw text and a dictionary of annotations.
@ -967,12 +968,13 @@ for i in range(20):
nlp.update(examples)
```
`Language.begin_training` and `Pipe.begin_training` have been renamed to
[`Language.initialize`](/api/language#initialize) and
[`Pipe.initialize`](/api/pipe#initialize), and the methods now take a function
that returns a sequence of `Example` objects to initialize the model instead of
a list of tuples. The data examples are used to **initialize the models** of
trainable pipeline components, which includes validating the network,
`Language.begin_training` and `TrainablePipe.begin_training` have been renamed
to [`Language.initialize`](/api/language#initialize) and
[`TrainablePipe.initialize`](/api/pipe#initialize), and the methods now take a
function that returns a sequence of `Example` objects to initialize the model
instead of a list of tuples. The data examples are used to **initialize the
models** of trainable pipeline components, which includes validating the
network,
[inferring missing shapes](https://thinc.ai/docs/usage-models#validation) and
setting up the label scheme.
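As a minimal sketch of the new pattern (assuming `train_examples` is a list of `Example` objects):

```python
# v2.x: optimizer = nlp.begin_training()
# v3.0: pass a callable that returns Example objects
optimizer = nlp.initialize(get_examples=lambda: train_examples)
```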