Update docstrings, docs and pipe consistency

This commit is contained in:
Ines Montani 2020-07-28 13:37:31 +02:00
parent 0094cb0d04
commit ae4d8a6ffd
19 changed files with 955 additions and 133 deletions

View File

@ -1100,13 +1100,12 @@ class Language:
return scorer.score(examples) return scorer.score(examples)
@contextmanager @contextmanager
def use_params(self, params: dict, **cfg): def use_params(self, params: dict):
"""Replace weights of models in the pipeline with those provided in the """Replace weights of models in the pipeline with those provided in the
params dictionary. Can be used as a contextmanager, in which case, params dictionary. Can be used as a contextmanager, in which case,
models go back to their original weights after the block. models go back to their original weights after the block.
params (dict): A dictionary of parameters keyed by model ID. params (dict): A dictionary of parameters keyed by model ID.
**cfg: Config parameters.
EXAMPLE: EXAMPLE:
>>> with nlp.use_params(optimizer.averages): >>> with nlp.use_params(optimizer.averages):

View File

@ -128,7 +128,8 @@ class EntityLinker(Pipe):
def begin_training( def begin_training(
self, self,
get_examples: Callable = lambda: [], get_examples: Callable[[], Iterable[Example]] = lambda: [],
*,
pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None, pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None,
sgd: Optional[Optimizer] = None, sgd: Optional[Optimizer] = None,
) -> Optimizer: ) -> Optimizer:
@ -273,7 +274,7 @@ class EntityLinker(Pipe):
stream (Iterable[Doc]): A stream of documents. stream (Iterable[Doc]): A stream of documents.
batch_size (int): The number of documents to buffer. batch_size (int): The number of documents to buffer.
YIELDS (Doc): PRocessed documents in order. YIELDS (Doc): Processed documents in order.
DOCS: https://spacy.io/api/entitylinker#pipe DOCS: https://spacy.io/api/entitylinker#pipe
""" """

View File

@ -97,7 +97,7 @@ class Morphologizer(Tagger):
"""Add a new label to the pipe. """Add a new label to the pipe.
label (str): The label to add. label (str): The label to add.
RETURNS (int): 1 RETURNS (int): 0 if label is already present, otherwise 1.
DOCS: https://spacy.io/api/morphologizer#add_label DOCS: https://spacy.io/api/morphologizer#add_label
""" """

View File

@ -8,41 +8,51 @@ from ..errors import Errors
from .. import util from .. import util
def deserialize_config(path):
if path.exists():
return srsly.read_json(path)
else:
return {}
class Pipe: class Pipe:
"""This class is not instantiated directly. Components inherit from it, and """This class is a base class and not instantiated directly. Trainable
it defines the interface that components should follow to function as pipeline components like the EntityRecognizer or TextCategorizer inherit
components in a spaCy analysis pipeline. from it and it defines the interface that components should follow to
function as trainable components in a spaCy pipeline.
DOCS: https://spacy.io/api/pipe
""" """
name = None name = None
def __init__(self, vocab, model, name, **cfg): def __init__(self, vocab, model, name, **cfg):
"""Create a new pipe instance.""" """Initialize a pipeline component.
vocab (Vocab): The shared vocabulary.
model (thinc.api.Model): The Thinc Model powering the pipeline component.
name (str): The component instance name.
**cfg: Additonal settings and config parameters.
DOCS: https://spacy.io/api/pipe#init
"""
raise NotImplementedError raise NotImplementedError
def __call__(self, Doc doc): def __call__(self, Doc doc):
"""Apply the pipe to one document. The document is """Add context-sensitive embeddings to the Doc.tensor attribute.
modified in-place, and returned.
Both __call__ and pipe should delegate to the `predict()` docs (Doc): The Doc to preocess.
and `set_annotations()` methods. RETURNS (Doc): The processed Doc.
DOCS: https://spacy.io/api/pipe#call
""" """
scores = self.predict([doc]) scores = self.predict([doc])
self.set_annotations([doc], scores) self.set_annotations([doc], scores)
return doc return doc
def pipe(self, stream, batch_size=128): def pipe(self, stream, *, batch_size=128):
"""Apply the pipe to a stream of documents. """Apply the pipe to a stream of documents. This usually happens under
the hood when the nlp object is called on a text and all components are
applied to the Doc.
Both __call__ and pipe should delegate to the `predict()` stream (Iterable[Doc]): A stream of documents.
and `set_annotations()` methods. batch_size (int): The number of documents to buffer.
YIELDS (Doc): Processed documents in order.
DOCS: https://spacy.io/api/pipe#pipe
""" """
for docs in util.minibatch(stream, size=batch_size): for docs in util.minibatch(stream, size=batch_size):
scores = self.predict(docs) scores = self.predict(docs)
@ -50,38 +60,90 @@ class Pipe:
yield from docs yield from docs
def predict(self, docs): def predict(self, docs):
"""Apply the pipeline's model to a batch of docs, without """Apply the pipeline's model to a batch of docs, without modifying them.
modifying them. Returns a single tensor for a batch of documents.
docs (Iterable[Doc]): The documents to predict.
RETURNS: Vector representations for each token in the documents.
DOCS: https://spacy.io/api/pipe#predict
""" """
raise NotImplementedError raise NotImplementedError
def set_annotations(self, docs, scores): def set_annotations(self, docs, scores):
"""Modify a batch of documents, using pre-computed scores.""" """Modify a batch of documents, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify.
tokvecses: The tensors to set, produced by Pipe.predict.
DOCS: https://spacy.io/api/pipe#predict
"""
raise NotImplementedError raise NotImplementedError
def rehearse(self, examples, sgd=None, losses=None, **config): def rehearse(self, examples, *, sgd=None, losses=None, **config):
"""Perform a "rehearsal" update from a batch of data. Rehearsal updates
teach the current model to make predictions similar to an initial model,
to try to address the "catastrophic forgetting" problem. This feature is
experimental.
examples (Iterable[Example]): A batch of Example objects.
drop (float): The dropout rate.
sgd (thinc.api.Optimizer): The optimizer.
losses (Dict[str, float]): Optional record of the loss during training.
Updated using the component name as the key.
RETURNS (Dict[str, float]): The updated losses dictionary.
DOCS: https://spacy.io/api/pipe#rehearse
"""
pass pass
def get_loss(self, examples, scores): def get_loss(self, examples, scores):
"""Find the loss and gradient of loss for the batch of """Find the loss and gradient of loss for the batch of documents and
examples (with embedded docs) and their predicted scores.""" their predicted scores.
examples (Iterable[Examples]): The batch of examples.
scores: Scores representing the model's predictions.
RETUTNRS (Tuple[float, float]): The loss and the gradient.
DOCS: https://spacy.io/api/pipe#get_loss
"""
raise NotImplementedError raise NotImplementedError
def add_label(self, label): def add_label(self, label):
"""Add an output label, to be predicted by the model. """Add an output label, to be predicted by the model. It's possible to
extend pretrained models with new labels, but care should be taken to
avoid the "catastrophic forgetting" problem.
It's possible to extend pretrained models with new labels, label (str): The label to add.
but care should be taken to avoid the "catastrophic forgetting" RETURNS (int): 0 if label is already present, otherwise 1.
problem.
DOCS: https://spacy.io/api/pipe#add_label
""" """
raise NotImplementedError raise NotImplementedError
def create_optimizer(self): def create_optimizer(self):
"""Create an optimizer for the pipeline component.
RETURNS (thinc.api.Optimizer): The optimizer.
DOCS: https://spacy.io/api/pipe#create_optimizer
"""
return create_default_optimizer() return create_default_optimizer()
def begin_training(self, get_examples=lambda: [], *, pipeline=None, sgd=None): def begin_training(self, get_examples=lambda: [], *, pipeline=None, sgd=None):
"""Initialize the pipe for training, using data exampes if available. """Initialize the pipe for training, using data examples if available.
If no model has been initialized yet, the model is added."""
get_examples (Callable[[], Iterable[Example]]): Optional function that
returns gold-standard Example objects.
pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
components that this component is part of. Corresponds to
nlp.pipeline.
sgd (thinc.api.Optimizer): Optional optimizer. Will be created with
create_optimizer if it doesn't exist.
RETURNS (thinc.api.Optimizer): The optimizer.
DOCS: https://spacy.io/api/pipe#begin_training
"""
self.model.initialize() self.model.initialize()
if hasattr(self, "vocab"): if hasattr(self, "vocab"):
link_vectors_to_models(self.vocab) link_vectors_to_models(self.vocab)
@ -90,6 +152,7 @@ class Pipe:
return sgd return sgd
def set_output(self, nO): def set_output(self, nO):
# TODO: document this across components?
if self.model.has_dim("nO") is not False: if self.model.has_dim("nO") is not False:
self.model.set_dim("nO", nO) self.model.set_dim("nO", nO)
if self.model.has_ref("output_layer"): if self.model.has_ref("output_layer"):
@ -99,6 +162,7 @@ class Pipe:
"""Get non-zero gradients of the model's parameters, as a dictionary """Get non-zero gradients of the model's parameters, as a dictionary
keyed by the parameter ID. The values are (weights, gradients) tuples. keyed by the parameter ID. The values are (weights, gradients) tuples.
""" """
# TODO: How is this used?
gradients = {} gradients = {}
queue = [self.model] queue = [self.model]
seen = set() seen = set()
@ -113,18 +177,33 @@ class Pipe:
return gradients return gradients
def use_params(self, params): def use_params(self, params):
"""Modify the pipe's model, to use the given parameter values.""" """Modify the pipe's model, to use the given parameter values. At the
end of the context, the original parameters are restored.
params (dict): The parameter values to use in the model.
DOCS: https://spacy.io/api/pipe#use_params
"""
with self.model.use_params(params): with self.model.use_params(params):
yield yield
def score(self, examples, **kwargs): def score(self, examples, **kwargs):
"""Score a batch of examples.
examples (Iterable[Example]): The examples to score.
RETURNS (Dict[str, Any]): The scores.
DOCS: https://spacy.io/api/pipe#score
"""
return {} return {}
def to_bytes(self, exclude=tuple()): def to_bytes(self, exclude=tuple()):
"""Serialize the pipe to a bytestring. """Serialize the pipe to a bytestring.
exclude (list): String names of serialization fields to exclude. exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (bytes): The serialized object. RETURNS (bytes): The serialized object.
DOCS: https://spacy.io/api/pipe#to_bytes
""" """
serialize = {} serialize = {}
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg) serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
@ -134,7 +213,13 @@ class Pipe:
return util.to_bytes(serialize, exclude) return util.to_bytes(serialize, exclude)
def from_bytes(self, bytes_data, exclude=tuple()): def from_bytes(self, bytes_data, exclude=tuple()):
"""Load the pipe from a bytestring.""" """Load the pipe from a bytestring.
exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (Pipe): The loaded object.
DOCS: https://spacy.io/api/pipe#from_bytes
"""
def load_model(b): def load_model(b):
try: try:
@ -151,7 +236,13 @@ class Pipe:
return self return self
def to_disk(self, path, exclude=tuple()): def to_disk(self, path, exclude=tuple()):
"""Serialize the pipe to disk.""" """Serialize the pipe to disk.
path (str / Path): Path to a directory.
exclude (Iterable[str]): String names of serialization fields to exclude.
DOCS: https://spacy.io/api/pipe#to_disk
"""
serialize = {} serialize = {}
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg) serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
serialize["vocab"] = lambda p: self.vocab.to_disk(p) serialize["vocab"] = lambda p: self.vocab.to_disk(p)
@ -159,7 +250,14 @@ class Pipe:
util.to_disk(path, serialize, exclude) util.to_disk(path, serialize, exclude)
def from_disk(self, path, exclude=tuple()): def from_disk(self, path, exclude=tuple()):
"""Load the pipe from disk.""" """Load the pipe from disk.
path (str / Path): Path to a directory.
exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (Pipe): The loaded object.
DOCS: https://spacy.io/api/pipe#from_disk
"""
def load_model(p): def load_model(p):
try: try:
@ -173,3 +271,10 @@ class Pipe:
deserialize["model"] = load_model deserialize["model"] = load_model
util.from_disk(path, deserialize, exclude) util.from_disk(path, deserialize, exclude)
return self return self
def deserialize_config(path):
if path.exists():
return srsly.read_json(path)
else:
return {}

View File

@ -329,7 +329,7 @@ class Tagger(Pipe):
label (str): The label to add. label (str): The label to add.
values (Dict[int, str]): Optional values to map to the label, e.g. a values (Dict[int, str]): Optional values to map to the label, e.g. a
tag map dictionary. tag map dictionary.
RETURNS (int): 1 RETURNS (int): 0 if label is already present, otherwise 1.
DOCS: https://spacy.io/api/tagger#add_label DOCS: https://spacy.io/api/tagger#add_label
""" """
@ -355,10 +355,6 @@ class Tagger(Pipe):
self.vocab.morphology.load_tag_map(tag_map) self.vocab.morphology.load_tag_map(tag_map)
return 1 return 1
def use_params(self, params):
with self.model.use_params(params):
yield
def score(self, examples, **kwargs): def score(self, examples, **kwargs):
"""Score a batch of examples. """Score a batch of examples.

View File

@ -56,7 +56,17 @@ dropout = null
"textcat", "textcat",
assigns=["doc.cats"], assigns=["doc.cats"],
default_config={"labels": [], "model": DEFAULT_TEXTCAT_MODEL}, default_config={"labels": [], "model": DEFAULT_TEXTCAT_MODEL},
scores=["cats_score", "cats_score_desc", "cats_p", "cats_r", "cats_f", "cats_macro_f", "cats_macro_auc", "cats_f_per_type", "cats_macro_auc_per_type"], scores=[
"cats_score",
"cats_score_desc",
"cats_p",
"cats_r",
"cats_f",
"cats_macro_f",
"cats_macro_auc",
"cats_f_per_type",
"cats_macro_auc_per_type",
],
default_score_weights={"cats_score": 1.0}, default_score_weights={"cats_score": 1.0},
) )
def make_textcat( def make_textcat(
@ -120,7 +130,7 @@ class TextCategorizer(Pipe):
stream (Iterable[Doc]): A stream of documents. stream (Iterable[Doc]): A stream of documents.
batch_size (int): The number of documents to buffer. batch_size (int): The number of documents to buffer.
YIELDS (Doc): PRocessed documents in order. YIELDS (Doc): Processed documents in order.
DOCS: https://spacy.io/api/textcategorizer#pipe DOCS: https://spacy.io/api/textcategorizer#pipe
""" """
@ -288,7 +298,7 @@ class TextCategorizer(Pipe):
"""Add a new label to the pipe. """Add a new label to the pipe.
label (str): The label to add. label (str): The label to add.
RETURNS (int): 1. RETURNS (int): 0 if label is already present, otherwise 1.
DOCS: https://spacy.io/api/textcategorizer#add_label DOCS: https://spacy.io/api/textcategorizer#add_label
""" """

View File

@ -34,10 +34,13 @@ def make_tok2vec(nlp: Language, name: str, model: Model) -> "Tok2Vec":
class Tok2Vec(Pipe): class Tok2Vec(Pipe):
def __init__(self, vocab: Vocab, model: Model, name: str = "tok2vec") -> None: def __init__(self, vocab: Vocab, model: Model, name: str = "tok2vec") -> None:
"""Construct a new statistical model. Weights are not allocated on """Initialize a tok2vec component.
initialisation.
vocab (Vocab): A `Vocab` instance. The model must share the same `Vocab` vocab (Vocab): The shared vocabulary.
instance with the `Doc` objects it will process. model (thinc.api.Model): The Thinc Model powering the pipeline component.
name (str): The component instance name.
DOCS: https://spacy.io/api/tok2vec#init
""" """
self.vocab = vocab self.vocab = vocab
self.model = model self.model = model
@ -57,20 +60,27 @@ class Tok2Vec(Pipe):
self.add_listener(node) self.add_listener(node)
def __call__(self, doc: Doc) -> Doc: def __call__(self, doc: Doc) -> Doc:
"""Add context-sensitive vectors to a `Doc`, e.g. from a CNN or LSTM """Add context-sensitive embeddings to the Doc.tensor attribute.
model. Vectors are set to the `Doc.tensor` attribute.
docs (Doc or iterable): One or more documents to add vectors to. docs (Doc): The Doc to preocess.
RETURNS (dict or None): Intermediate computations. RETURNS (Doc): The processed Doc.
DOCS: https://spacy.io/api/tok2vec#call
""" """
tokvecses = self.predict([doc]) tokvecses = self.predict([doc])
self.set_annotations([doc], tokvecses) self.set_annotations([doc], tokvecses)
return doc return doc
def pipe(self, stream: Iterator[Doc], batch_size: int = 128) -> Iterator[Doc]: def pipe(self, stream: Iterator[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
"""Process `Doc` objects as a stream. """Apply the pipe to a stream of documents. This usually happens under
stream (iterator): A sequence of `Doc` objects to process. the hood when the nlp object is called on a text and all components are
batch_size (int): Number of `Doc` objects to group. applied to the Doc.
YIELDS (iterator): A sequence of `Doc` objects, in order of input.
stream (Iterable[Doc]): A stream of documents.
batch_size (int): The number of documents to buffer.
YIELDS (Doc): Processed documents in order.
DOCS: https://spacy.io/api/tok2vec#pipe
""" """
for docs in minibatch(stream, batch_size): for docs in minibatch(stream, batch_size):
docs = list(docs) docs = list(docs)
@ -78,10 +88,14 @@ class Tok2Vec(Pipe):
self.set_annotations(docs, tokvecses) self.set_annotations(docs, tokvecses)
yield from docs yield from docs
def predict(self, docs: Sequence[Doc]): def predict(self, docs: Iterable[Doc]):
"""Return a single tensor for a batch of documents. """Apply the pipeline's model to a batch of docs, without modifying them.
docs (iterable): A sequence of `Doc` objects. Returns a single tensor for a batch of documents.
RETURNS (object): Vector representations for each token in the documents.
docs (Iterable[Doc]): The documents to predict.
RETURNS: Vector representations for each token in the documents.
DOCS: https://spacy.io/api/tok2vec#predict
""" """
tokvecs = self.model.predict(docs) tokvecs = self.model.predict(docs)
batch_id = Tok2VecListener.get_batch_id(docs) batch_id = Tok2VecListener.get_batch_id(docs)
@ -90,9 +104,12 @@ class Tok2Vec(Pipe):
return tokvecs return tokvecs
def set_annotations(self, docs: Sequence[Doc], tokvecses) -> None: def set_annotations(self, docs: Sequence[Doc], tokvecses) -> None:
"""Set the tensor attribute for a batch of documents. """Modify a batch of documents, using pre-computed scores.
docs (iterable): A sequence of `Doc` objects.
tokvecs (object): Vector representation for each token in the documents. docs (Iterable[Doc]): The documents to modify.
tokvecses: The tensors to set, produced by Tok2Vec.predict.
DOCS: https://spacy.io/api/tok2vec#predict
""" """
for doc, tokvecs in zip(docs, tokvecses): for doc, tokvecs in zip(docs, tokvecses):
assert tokvecs.shape[0] == len(doc) assert tokvecs.shape[0] == len(doc)
@ -107,13 +124,19 @@ class Tok2Vec(Pipe):
losses: Optional[Dict[str, float]] = None, losses: Optional[Dict[str, float]] = None,
set_annotations: bool = False, set_annotations: bool = False,
): ):
"""Update the model. """Learn from a batch of documents and gold-standard information,
examples (Iterable[Example]): A batch of examples updating the pipe's model.
drop (float): The droput rate.
sgd (Optimizer): An optimizer. examples (Iterable[Example]): A batch of Example objects.
losses (Dict[str, float]): Dictionary to update with the loss, keyed by component. drop (float): The dropout rate.
set_annotations (bool): whether or not to update the examples with the predictions set_annotations (bool): Whether or not to update the Example objects
RETURNS (Dict[str, float]): The updated losses dictionary with the predictions.
sgd (thinc.api.Optimizer): The optimizer.
losses (Dict[str, float]): Optional record of the loss during training.
Updated using the component name as the key.
RETURNS (Dict[str, float]): The updated losses dictionary.
DOCS: https://spacy.io/api/tok2vec#update
""" """
if losses is None: if losses is None:
losses = {} losses = {}
@ -122,7 +145,6 @@ class Tok2Vec(Pipe):
docs = [docs] docs = [docs]
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
tokvecs, bp_tokvecs = self.model.begin_update(docs) tokvecs, bp_tokvecs = self.model.begin_update(docs)
d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs] d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
losses.setdefault(self.name, 0.0) losses.setdefault(self.name, 0.0)
@ -156,14 +178,23 @@ class Tok2Vec(Pipe):
def begin_training( def begin_training(
self, self,
get_examples: Callable = lambda: [], get_examples: Callable[[], Iterable[Example]] = lambda: [],
*,
pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None, pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None,
sgd: Optional[Optimizer] = None, sgd: Optional[Optimizer] = None,
): ):
"""Allocate models and pre-process training data """Initialize the pipe for training, using data examples if available.
get_examples (function): Function returning example training data. get_examples (Callable[[], Iterable[Example]]): Optional function that
pipeline (list): The pipeline the model is part of. returns gold-standard Example objects.
pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
components that this component is part of. Corresponds to
nlp.pipeline.
sgd (thinc.api.Optimizer): Optional optimizer. Will be created with
create_optimizer if it doesn't exist.
RETURNS (thinc.api.Optimizer): The optimizer.
DOCS: https://spacy.io/api/tok2vec#begin_training
""" """
docs = [Doc(Vocab(), words=["hello"])] docs = [Doc(Vocab(), words=["hello"])]
self.model.initialize(X=docs) self.model.initialize(X=docs)

View File

@ -123,6 +123,8 @@ cdef class Parser:
resized = True resized = True
if resized: if resized:
self._resize() self._resize()
return 1
return 0
def _resize(self): def _resize(self):
self.model.attrs["resize_output"](self.model, self.moves.n_moves) self.model.attrs["resize_output"](self.model, self.moves.n_moves)

View File

@ -1182,6 +1182,7 @@ VECTORS_KEY = "spacy_pretrained_vectors"
def create_default_optimizer() -> Optimizer: def create_default_optimizer() -> Optimizer:
# TODO: Do we still want to allow env_opt?
learn_rate = env_opt("learn_rate", 0.001) learn_rate = env_opt("learn_rate", 0.001)
beta1 = env_opt("optimizer_B1", 0.9) beta1 = env_opt("optimizer_B1", 0.9)
beta2 = env_opt("optimizer_B2", 0.999) beta2 = env_opt("optimizer_B2", 0.999)

View File

@ -248,19 +248,20 @@ component.
## DependencyParser.use_params {#use_params tag="method, contextmanager"} ## DependencyParser.use_params {#use_params tag="method, contextmanager"}
Modify the pipe's model, to use the given parameter values. Modify the pipe's model, to use the given parameter values. At the end of the
context, the original parameters are restored.
> #### Example > #### Example
> >
> ```python > ```python
> parser = DependencyParser(nlp.vocab) > parser = DependencyParser(nlp.vocab)
> with parser.use_params(): > with parser.use_params(optimizer.averages):
> parser.to_disk("/best_model") > parser.to_disk("/best_model")
> ``` > ```
| Name | Type | Description | | Name | Type | Description |
| -------- | ---- | ---------------------------------------------------------------------------------------------------------- | | -------- | ---- | ----------------------------------------- |
| `params` | - | The parameter values to use in the model. At the end of the context, the original parameters are restored. | | `params` | dict | The parameter values to use in the model. |
## DependencyParser.add_label {#add_label tag="method"} ## DependencyParser.add_label {#add_label tag="method"}
@ -273,9 +274,10 @@ Add a new label to the pipe.
> parser.add_label("MY_LABEL") > parser.add_label("MY_LABEL")
> ``` > ```
| Name | Type | Description | | Name | Type | Description |
| ------- | ---- | ----------------- | | ----------- | ---- | --------------------------------------------------- |
| `label` | str | The label to add. | | `label` | str | The label to add. |
| **RETURNS** | int | `0` if the label is already present, otherwise `1`. |
## DependencyParser.to_disk {#to_disk tag="method"} ## DependencyParser.to_disk {#to_disk tag="method"}

View File

@ -239,7 +239,8 @@ Create an optimizer for the pipeline component.
## EntityLinker.use_params {#use_params tag="method, contextmanager"} ## EntityLinker.use_params {#use_params tag="method, contextmanager"}
Modify the pipe's EL model, to use the given parameter values. Modify the pipe's model, to use the given parameter values. At the end of the
context, the original parameters are restored.
> #### Example > #### Example
> >
@ -249,9 +250,9 @@ Modify the pipe's EL model, to use the given parameter values.
> entity_linker.to_disk("/best_model") > entity_linker.to_disk("/best_model")
> ``` > ```
| Name | Type | Description | | Name | Type | Description |
| -------- | ---- | ---------------------------------------------------------------------------------------------------------- | | -------- | ---- | ----------------------------------------- |
| `params` | dict | The parameter values to use in the model. At the end of the context, the original parameters are restored. | | `params` | dict | The parameter values to use in the model. |
## EntityLinker.to_disk {#to_disk tag="method"} ## EntityLinker.to_disk {#to_disk tag="method"}

View File

@ -247,7 +247,8 @@ Create an optimizer for the pipeline component.
## EntityRecognizer.use_params {#use_params tag="method, contextmanager"} ## EntityRecognizer.use_params {#use_params tag="method, contextmanager"}
Modify the pipe's model, to use the given parameter values. Modify the pipe's model, to use the given parameter values. At the end of the
context, the original parameters are restored.
> #### Example > #### Example
> >
@ -257,9 +258,9 @@ Modify the pipe's model, to use the given parameter values.
> ner.to_disk("/best_model") > ner.to_disk("/best_model")
> ``` > ```
| Name | Type | Description | | Name | Type | Description |
| -------- | ---- | ---------------------------------------------------------------------------------------------------------- | | -------- | ---- | ----------------------------------------- |
| `params` | dict | The parameter values to use in the model. At the end of the context, the original parameters are restored. | | `params` | dict | The parameter values to use in the model. |
## EntityRecognizer.add_label {#add_label tag="method"} ## EntityRecognizer.add_label {#add_label tag="method"}
@ -272,9 +273,10 @@ Add a new label to the pipe.
> ner.add_label("MY_LABEL") > ner.add_label("MY_LABEL")
> ``` > ```
| Name | Type | Description | | Name | Type | Description |
| ------- | ---- | ----------------- | | ----------- | ---- | --------------------------------------------------- |
| `label` | str | The label to add. | | `label` | str | The label to add. |
| **RETURNS** | int | `0` if the label is already present, otherwise `1`. |
## EntityRecognizer.to_disk {#to_disk tag="method"} ## EntityRecognizer.to_disk {#to_disk tag="method"}

View File

@ -271,7 +271,6 @@ their original weights after the block.
| Name | Type | Description | | Name | Type | Description |
| -------- | ---- | --------------------------------------------- | | -------- | ---- | --------------------------------------------- |
| `params` | dict | A dictionary of parameters keyed by model ID. | | `params` | dict | A dictionary of parameters keyed by model ID. |
| `**cfg` | - | Config parameters. |
## Language.create_pipe {#create_pipe tag="method" new="2"} ## Language.create_pipe {#create_pipe tag="method" new="2"}

View File

@ -233,19 +233,20 @@ Create an optimizer for the pipeline component.
## Morphologizer.use_params {#use_params tag="method, contextmanager"} ## Morphologizer.use_params {#use_params tag="method, contextmanager"}
Modify the pipe's model, to use the given parameter values. Modify the pipe's model, to use the given parameter values. At the end of the
context, the original parameters are restored.
> #### Example > #### Example
> >
> ```python > ```python
> morphologizer = nlp.add_pipe("morphologizer") > morphologizer = nlp.add_pipe("morphologizer")
> with morphologizer.use_params(): > with morphologizer.use_params(optimizer.averages):
> morphologizer.to_disk("/best_model") > morphologizer.to_disk("/best_model")
> ``` > ```
| Name | Type | Description | | Name | Type | Description |
| -------- | ---- | ---------------------------------------------------------------------------------------------------------- | | -------- | ---- | ----------------------------------------- |
| `params` | - | The parameter values to use in the model. At the end of the context, the original parameters are restored. | | `params` | dict | The parameter values to use in the model. |
## Morphologizer.add_label {#add_label tag="method"} ## Morphologizer.add_label {#add_label tag="method"}
@ -259,9 +260,10 @@ both `pos` and `morph`, the label should include the UPOS as the feature `POS`.
> morphologizer.add_label("Mood=Ind|POS=VERB|Tense=Past|VerbForm=Fin") > morphologizer.add_label("Mood=Ind|POS=VERB|Tense=Past|VerbForm=Fin")
> ``` > ```
| Name | Type | Description | | Name | Type | Description |
| ------- | ---- | ----------------- | | ----------- | ---- | --------------------------------------------------- |
| `label` | str | The label to add. | | `label` | str | The label to add. |
| **RETURNS** | int | `0` if the label is already present, otherwise `1`. |
## Morphologizer.to_disk {#to_disk tag="method"} ## Morphologizer.to_disk {#to_disk tag="method"}

View File

@ -1,6 +1,381 @@
--- ---
title: Pipe title: Pipe
tag: class tag: class
teaser: Base class for trainable pipeline components
--- ---
TODO: write This class is a base class and **not instantiated directly**. Trainable pipeline
components like the [`EntityRecognizer`](/api/entityrecognizer) or
[`TextCategorizer`](/api/textcategorizer) inherit from it and it defines the
interface that components should follow to function as trainable components in a
spaCy pipeline.
```python
https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/pipe.pyx
```
## Pipe.\_\_init\_\_ {#init tag="method"}
> #### Example
>
> ```python
> from spacy.pipeline import Pipe
> from spacy.language import Language
>
> class CustomPipe(Pipe):
> ...
>
> @Language.factory("your_custom_pipe", default_config={"model": MODEL})
> def make_custom_pipe(nlp, name, model):
> return CustomPipe(nlp.vocab, model, name)
> ```
Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#create_pipe).
<Infobox variant="danger">
This method needs to be overwritten with your own custom `__init__` method.
</Infobox>
| Name | Type | Description |
| ------- | ------------------------------------------ | ------------------------------------------------------------------------------------------- |
| `vocab` | `Vocab` | The shared vocabulary. |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
| `name` | str | String name of the component instance. Used to add entries to the `losses` during training. |
| `**cfg` | | Additional config parameters and settings. |
## Pipe.\_\_call\_\_ {#call tag="method"}
Apply the pipe to one document. The document is modified in place, and returned.
This usually happens under the hood when the `nlp` object is called on a text
and all pipeline components are applied to the `Doc` in order. Both
[`__call__`](/api/pipe#call) and [`pipe`](/api/pipe#pipe) delegate to the
[`predict`](/api/pipe#predict) and
[`set_annotations`](/api/pipe#set_annotations) methods.
> #### Example
>
> ```python
> doc = nlp("This is a sentence.")
> pipe = nlp.add_pipe("your_custom_pipe")
> # This usually happens under the hood
> processed = pipe(doc)
> ```
| Name | Type | Description |
| ----------- | ----- | ------------------------ |
| `doc` | `Doc` | The document to process. |
| **RETURNS** | `Doc` | The processed document. |
## Pipe.pipe {#pipe tag="method"}
Apply the pipe to a stream of documents. This usually happens under the hood
when the `nlp` object is called on a text and all pipeline components are
applied to the `Doc` in order. Both [`__call__`](/api/pipe#call) and
[`pipe`](/api/pipe#pipe) delegate to the [`predict`](/api/pipe#predict) and
[`set_annotations`](/api/pipe#set_annotations) methods.
> #### Example
>
> ```python
> pipe = nlp.add_pipe("your_custom_pipe")
> for doc in pipe.pipe(docs, batch_size=50):
> pass
> ```
| Name | Type | Description |
| -------------- | --------------- | ----------------------------------------------------- |
| `stream` | `Iterable[Doc]` | A stream of documents. |
| _keyword-only_ | | |
| `batch_size` | int | The number of documents to buffer. Defaults to `128`. |
| **YIELDS** | `Doc` | The processed documents in order. |
## Pipe.begin_training {#begin_training tag="method"}
Initialize the pipe for training, using data examples if available. Return an
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.
> #### Example
>
> ```python
> pipe = nlp.add_pipe("your_custom_pipe")
> optimizer = pipe.begin_training(pipeline=nlp.pipeline)
> ```
| Name | Type | Description |
| -------------- | --------------------------------------------------- | ---------------------------------------------------------------------------------------------------------- |
| `get_examples` | `Callable[[], Iterable[Example]]` | Optional function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. |
| _keyword-only_ | | |
| `pipeline` | `List[Tuple[str, Callable]]` | Optional list of pipeline components that this component is part of. |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | An optional optimizer. Will be created via [`create_optimizer`](/api/pipe#create_optimizer) if not set. |
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
## Pipe.predict {#predict tag="method"}
Apply the pipeline's model to a batch of docs, without modifying them.
<Infobox variant="danger">
This method needs to be overwritten with your own custom `predict` method.
</Infobox>
> #### Example
>
> ```python
> pipe = nlp.add_pipe("your_custom_pipe")
> scores = pipe.predict([doc1, doc2])
> ```
| Name | Type | Description |
| ----------- | --------------- | ----------------------------------------- |
| `docs` | `Iterable[Doc]` | The documents to predict. |
| **RETURNS** | - | The model's prediction for each document. |
## Pipe.set_annotations {#set_annotations tag="method"}
Modify a batch of documents, using pre-computed scores.
<Infobox variant="danger">
This method needs to be overwritten with your own custom `set_annotations`
method.
</Infobox>
> #### Example
>
> ```python
> pipe = nlp.add_pipe("your_custom_pipe")
> scores = pipe.predict(docs)
> pipe.set_annotations(docs, scores)
> ```
| Name | Type | Description |
| -------- | --------------- | ---------------------------------------------- |
| `docs` | `Iterable[Doc]` | The documents to modify. |
| `scores` | - | The scores to set, produced by `Pipe.predict`. |
## Pipe.update {#update tag="method"}
Learn from a batch of documents and gold-standard information, updating the
pipe's model. Delegates to [`predict`](/api/pipe#predict).
<Infobox variant="danger">
This method needs to be overwritten with your own custom `update` method.
</Infobox>
> #### Example
>
> ```python
> pipe = nlp.add_pipe("your_custom_pipe")
> optimizer = nlp.begin_training()
> losses = pipe.update(examples, sgd=optimizer)
> ```
| Name | Type | Description |
| ----------------- | --------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
| _keyword-only_ | | |
| `drop` | float | The dropout rate. |
| `set_annotations` | bool | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](/api/pipe#set_annotations). |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
| `losses` | `Dict[str, float]` | Optional record of the loss during training. Updated using the component name as the key. |
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |
## Pipe.rehearse {#rehearse tag="method,experimental"}
Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
current model to make predictions similar to an initial model, to try to address
the "catastrophic forgetting" problem. This feature is experimental.
> #### Example
>
> ```python
> pipe = nlp.add_pipe("your_custom_pipe")
> optimizer = nlp.begin_training()
> losses = pipe.rehearse(examples, sgd=optimizer)
> ```
| Name | Type | Description |
| -------------- | --------------------------------------------------- | ----------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
| _keyword-only_ | | |
| `drop` | float | The dropout rate. |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
| `losses` | `Dict[str, float]` | Optional record of the loss during training. Updated using the component name as the key. |
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |
## Pipe.get_loss {#get_loss tag="method"}
Find the loss and gradient of loss for the batch of documents and their
predicted scores.
> #### Example
>
> ```python
> ner = nlp.add_pipe("ner")
> scores = ner.predict([eg.predicted for eg in examples])
> loss, d_loss = ner.get_loss(examples, scores)
> ```
| Name | Type | Description |
| ----------- | --------------------- | --------------------------------------------------- |
| `examples` | `Iterable[Example]` | The batch of examples. |
| `scores` | | Scores representing the model's predictions. |
| **RETURNS** | `Tuple[float, float]` | The loss and the gradient, i.e. `(loss, gradient)`. |
## Pipe.score {#score tag="method" new="3"}
Score a batch of examples.
> #### Example
>
> ```python
> scores = pipe.score(examples)
> ```
| Name | Type | Description |
| ----------- | ------------------- | --------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The examples to score. |
| **RETURNS** | `Dict[str, Any]` | The scores, e.g. produced by the [`Scorer`](/api/scorer). |
## Pipe.create_optimizer {#create_optimizer tag="method"}
Create an optimizer for the pipeline component. Defaults to
[`Adam`](https://thinc.ai/docs/api-optimizers#adam) with default settings.
> #### Example
>
> ```python
> pipe = nlp.add_pipe("your_custom_pipe")
> optimizer = pipe.create_optimizer()
> ```
| Name | Type | Description |
| ----------- | --------------------------------------------------- | -------------- |
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
## Pipe.add_label {#add_label tag="method"}
Add a new label to the pipe. It's possible to extend pretrained models with new
labels, but care should be taken to avoid the "catastrophic forgetting" problem.
> #### Example
>
> ```python
> pipe = nlp.add_pipe("your_custom_pipe")
> pipe.add_label("MY_LABEL")
> ```
| Name | Type | Description |
| ----------- | ---- | --------------------------------------------------- |
| `label` | str | The label to add. |
| **RETURNS** | int | `0` if the label is already present, otherwise `1`. |
## Pipe.use_params {#use_params tag="method, contextmanager"}
Modify the pipe's model, to use the given parameter values. At the end of the
context, the original parameters are restored.
> #### Example
>
> ```python
> pipe = nlp.add_pipe("your_custom_pipe")
> with pipe.use_params(optimizer.averages):
> pipe.to_disk("/best_model")
> ```
| Name | Type | Description |
| -------- | ---- | ----------------------------------------- |
| `params` | dict | The parameter values to use in the model. |
## Pipe.to_disk {#to_disk tag="method"}
Serialize the pipe to disk.
> #### Example
>
> ```python
> pipe = nlp.add_pipe("your_custom_pipe")
> pipe.to_disk("/path/to/pipe")
> ```
| Name | Type | Description |
| --------- | --------------- | --------------------------------------------------------------------------------------------------------------------- |
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
## Pipe.from_disk {#from_disk tag="method"}
Load the pipe from disk. Modifies the object in place and returns it.
> #### Example
>
> ```python
> pipe = nlp.add_pipe("your_custom_pipe")
> pipe.from_disk("/path/to/pipe")
> ```
| Name | Type | Description |
| ----------- | --------------- | -------------------------------------------------------------------------- |
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `Pipe` | The modified pipe. |
## Pipe.to_bytes {#to_bytes tag="method"}
> #### Example
>
> ```python
> pipe = nlp.add_pipe("your_custom_pipe")
> pipe_bytes = pipe.to_bytes()
> ```
Serialize the pipe to a bytestring.
| Name | Type | Description |
| ----------- | --------------- | ------------------------------------------------------------------------- |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | bytes | The serialized form of the pipe. |
## Pipe.from_bytes {#from_bytes tag="method"}
Load the pipe from a bytestring. Modifies the object in place and returns it.
> #### Example
>
> ```python
> pipe_bytes = pipe.to_bytes()
> pipe = nlp.add_pipe("your_custom_pipe")
> pipe.from_bytes(pipe_bytes)
> ```
| Name | Type | Description |
| ------------ | --------------- | ------------------------------------------------------------------------- |
| `bytes_data` | bytes | The data to load from. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `Pipe` | The pipe. |
## Serialization fields {#serialization-fields}
During serialization, spaCy will export several data fields used to restore
different aspects of the object. If needed, you can exclude them from
serialization by passing in the string names via the `exclude` argument.
> #### Example
>
> ```python
> data = pipe.to_disk("/path", exclude=["vocab"])
> ```
| Name | Description |
| ------- | -------------------------------------------------------------- |
| `vocab` | The shared [`Vocab`](/api/vocab). |
| `cfg` | The config file. You usually don't want to exclude this. |
| `model` | The binary model data. You usually don't want to exclude this. |

View File

@ -265,19 +265,20 @@ Create an optimizer for the pipeline component.
## SentenceRecognizer.use_params {#use_params tag="method, contextmanager"} ## SentenceRecognizer.use_params {#use_params tag="method, contextmanager"}
Modify the pipe's model, to use the given parameter values. Modify the pipe's model, to use the given parameter values. At the end of the
context, the original parameters are restored.
> #### Example > #### Example
> >
> ```python > ```python
> senter = nlp.add_pipe("senter") > senter = nlp.add_pipe("senter")
> with senter.use_params(): > with senter.use_params(optimizer.averages):
> senter.to_disk("/best_model") > senter.to_disk("/best_model")
> ``` > ```
| Name | Type | Description | | Name | Type | Description |
| -------- | ---- | ---------------------------------------------------------------------------------------------------------- | | -------- | ---- | ----------------------------------------- |
| `params` | - | The parameter values to use in the model. At the end of the context, the original parameters are restored. | | `params` | dict | The parameter values to use in the model. |
## SentenceRecognizer.to_disk {#to_disk tag="method"} ## SentenceRecognizer.to_disk {#to_disk tag="method"}

View File

@ -263,19 +263,20 @@ Create an optimizer for the pipeline component.
## Tagger.use_params {#use_params tag="method, contextmanager"} ## Tagger.use_params {#use_params tag="method, contextmanager"}
Modify the pipe's model, to use the given parameter values. Modify the pipe's model, to use the given parameter values. At the end of the
context, the original parameters are restored.
> #### Example > #### Example
> >
> ```python > ```python
> tagger = nlp.add_pipe("tagger") > tagger = nlp.add_pipe("tagger")
> with tagger.use_params(): > with tagger.use_params(optimizer.averages):
> tagger.to_disk("/best_model") > tagger.to_disk("/best_model")
> ``` > ```
| Name | Type | Description | | Name | Type | Description |
| -------- | ---- | ---------------------------------------------------------------------------------------------------------- | | -------- | ---- | ----------------------------------------- |
| `params` | - | The parameter values to use in the model. At the end of the context, the original parameters are restored. | | `params` | dict | The parameter values to use in the model. |
## Tagger.add_label {#add_label tag="method"} ## Tagger.add_label {#add_label tag="method"}
@ -289,10 +290,11 @@ Add a new label to the pipe.
> tagger.add_label("MY_LABEL", {POS: "NOUN"}) > tagger.add_label("MY_LABEL", {POS: "NOUN"})
> ``` > ```
| Name | Type | Description | | Name | Type | Description |
| -------- | ---------------- | --------------------------------------------------------------- | | ----------- | ---------------- | --------------------------------------------------------------- |
| `label` | str | The label to add. | | `label` | str | The label to add. |
| `values` | `Dict[int, str]` | Optional values to map to the label, e.g. a tag map dictionary. | | `values` | `Dict[int, str]` | Optional values to map to the label, e.g. a tag map dictionary. |
| **RETURNS** | int | `0` if the label is already present, otherwise `1`. |
## Tagger.to_disk {#to_disk tag="method"} ## Tagger.to_disk {#to_disk tag="method"}

View File

@ -262,7 +262,8 @@ Score a batch of examples.
| Name | Type | Description | | Name | Type | Description |
| ---------------- | ------------------- | ---------------------------------------------------------------------- | | ---------------- | ------------------- | ---------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The examples to score. | _keyword-only_ | | | | `examples` | `Iterable[Example]` | The examples to score. |
| _keyword-only_ | | |
| `positive_label` | str | Optional positive label. | | `positive_label` | str | Optional positive label. |
| **RETURNS** | `Dict[str, Any]` | The scores, produced by [`Scorer.score_cats`](/api/scorer#score_cats). | | **RETURNS** | `Dict[str, Any]` | The scores, produced by [`Scorer.score_cats`](/api/scorer#score_cats). |
@ -292,9 +293,10 @@ Add a new label to the pipe.
> textcat.add_label("MY_LABEL") > textcat.add_label("MY_LABEL")
> ``` > ```
| Name | Type | Description | | Name | Type | Description |
| ------- | ---- | ----------------- | | ----------- | ---- | --------------------------------------------------- |
| `label` | str | The label to add. | | `label` | str | The label to add. |
| **RETURNS** | int | `0` if the label is already present, otherwise `1`. |
## TextCategorizer.use_params {#use_params tag="method, contextmanager"} ## TextCategorizer.use_params {#use_params tag="method, contextmanager"}
@ -304,13 +306,13 @@ Modify the pipe's model, to use the given parameter values.
> >
> ```python > ```python
> textcat = nlp.add_pipe("textcat") > textcat = nlp.add_pipe("textcat")
> with textcat.use_params(): > with textcat.use_params(optimizer.averages):
> textcat.to_disk("/best_model") > textcat.to_disk("/best_model")
> ``` > ```
| Name | Type | Description | | Name | Type | Description |
| -------- | ---- | ---------------------------------------------------------------------------------------------------------- | | -------- | ---- | ----------------------------------------- |
| `params` | - | The parameter values to use in the model. At the end of the context, the original parameters are restored. | | `params` | dict | The parameter values to use in the model. |
## TextCategorizer.to_disk {#to_disk tag="method"} ## TextCategorizer.to_disk {#to_disk tag="method"}

View File

@ -8,4 +8,295 @@ api_string_name: tok2vec
api_trainable: true api_trainable: true
--- ---
TODO: <!-- TODO: intro describing component -->
## Config and implementation {#config}
The default config is defined by the pipeline component factory and describes
how the component should be configured. You can override its settings via the
`config` argument on [`nlp.add_pipe`](/api/language#add_pipe) or in your
[`config.cfg` for training](/usage/training#config). See the
[model architectures](/api/architectures) documentation for details on the
architectures and their arguments and hyperparameters.
> #### Example
>
> ```python
> from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
> config = {"model": DEFAULT_TOK2VEC_MODEL}
> nlp.add_pipe("tok2vec", config=config)
> ```
| Setting | Type | Description | Default |
| ------- | ------------------------------------------ | ----------------- | ----------------------------------------------- |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The model to use. | [HashEmbedCNN](/api/architectures#HashEmbedCNN) |
```python
https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/tok2vec.py
```
## Tok2Vec.\_\_init\_\_ {#init tag="method"}
> #### Example
>
> ```python
> # Construction via add_pipe with default model
> tok2vec = nlp.add_pipe("tok2vec")
>
> # Construction via add_pipe with custom model
> config = {"model": {"@architectures": "my_tok2vec"}}
> parser = nlp.add_pipe("tok2vec", config=config)
>
> # Construction from class
> from spacy.pipeline import Tok2Vec
> tok2vec = Tok2Vec(nlp.vocab, model)
> ```
Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#create_pipe).
| Name | Type | Description |
| ------- | ------------------------------------------ | ------------------------------------------------------------------------------------------- |
| `vocab` | `Vocab` | The shared vocabulary. |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
| `name` | str | String name of the component instance. Used to add entries to the `losses` during training. |
## Tok2Vec.\_\_call\_\_ {#call tag="method"}
Apply the pipe to one document. The document is modified in place, and returned.
This usually happens under the hood when the `nlp` object is called on a text
and all pipeline components are applied to the `Doc` in order. Both
[`__call__`](/api/tok2vec#call) and [`pipe`](/api/tok2vec#pipe) delegate to the
[`predict`](/api/tok2vec#predict) and
[`set_annotations`](/api/tok2vec#set_annotations) methods.
> #### Example
>
> ```python
> doc = nlp("This is a sentence.")
> tok2vec = nlp.add_pipe("tok2vec")
> # This usually happens under the hood
> processed = tok2vec(doc)
> ```
| Name | Type | Description |
| ----------- | ----- | ------------------------ |
| `doc` | `Doc` | The document to process. |
| **RETURNS** | `Doc` | The processed document. |
## Tok2Vec.pipe {#pipe tag="method"}
Apply the pipe to a stream of documents. This usually happens under the hood
when the `nlp` object is called on a text and all pipeline components are
applied to the `Doc` in order. Both [`__call__`](/api/tok2vec#call) and
[`pipe`](/api/tok2vec#pipe) delegate to the [`predict`](/api/tok2vec#predict)
and [`set_annotations`](/api/tok2vec#set_annotations) methods.
> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> for doc in tok2vec.pipe(docs, batch_size=50):
> pass
> ```
| Name | Type | Description |
| -------------- | --------------- | ----------------------------------------------------- |
| `stream` | `Iterable[Doc]` | A stream of documents. |
| _keyword-only_ | | |
| `batch_size` | int | The number of documents to buffer. Defaults to `128`. |
| **YIELDS** | `Doc` | The processed documents in order. |
## Tok2Vec.begin_training {#begin_training tag="method"}
Initialize the pipe for training, using data examples if available. Return an
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.
> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> optimizer = tok2vec.begin_training(pipeline=nlp.pipeline)
> ```
| Name | Type | Description |
| -------------- | --------------------------------------------------- | ---------------------------------------------------------------------------------------------------------- |
| `get_examples` | `Callable[[], Iterable[Example]]` | Optional function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. |
| _keyword-only_ | | |
| `pipeline` | `List[Tuple[str, Callable]]` | Optional list of pipeline components that this component is part of. |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | An optional optimizer. Will be created via [`create_optimizer`](/api/tok2vec#create_optimizer) if not set. |
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
## Tok2Vec.predict {#predict tag="method"}
Apply the pipeline's model to a batch of docs, without modifying them.
> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> scores = tok2vec.predict([doc1, doc2])
> ```
| Name | Type | Description |
| ----------- | --------------- | ----------------------------------------- |
| `docs` | `Iterable[Doc]` | The documents to predict. |
| **RETURNS** | - | The model's prediction for each document. |
## Tok2Vec.set_annotations {#set_annotations tag="method"}
Modify a batch of documents, using pre-computed scores.
> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> scores = tok2vec.predict(docs)
> tok2vec.set_annotations(docs, scores)
> ```
| Name | Type | Description |
| -------- | --------------- | ------------------------------------------------- |
| `docs` | `Iterable[Doc]` | The documents to modify. |
| `scores` | - | The scores to set, produced by `Tok2Vec.predict`. |
## Tok2Vec.update {#update tag="method"}
Learn from a batch of documents and gold-standard information, updating the
pipe's model. Delegates to [`predict`](/api/tok2vec#predict).
> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> optimizer = nlp.begin_training()
> losses = tok2vec.update(examples, sgd=optimizer)
> ```
| Name | Type | Description |
| ----------------- | --------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
| _keyword-only_ | | |
| `drop` | float | The dropout rate. |
| `set_annotations` | bool | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](/api/tok2vec#set_annotations). |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
| `losses` | `Dict[str, float]` | Optional record of the loss during training. Updated using the component name as the key. |
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |
## Tok2Vec.create_optimizer {#create_optimizer tag="method"}
Create an optimizer for the pipeline component.
> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> optimizer = tok2vec.create_optimizer()
> ```
| Name | Type | Description |
| ----------- | --------------------------------------------------- | -------------- |
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
## Tok2Vec.use_params {#use_params tag="method, contextmanager"}
Modify the pipe's model, to use the given parameter values. At the end of the
context, the original parameters are restored.
> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> with tok2vec.use_params(optimizer.averages):
> tok2vec.to_disk("/best_model")
> ```
| Name | Type | Description |
| -------- | ---- | ----------------------------------------- |
| `params` | dict | The parameter values to use in the model. |
## Tok2Vec.to_disk {#to_disk tag="method"}
Serialize the pipe to disk.
> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> tok2vec.to_disk("/path/to/tok2vec")
> ```
| Name | Type | Description |
| --------- | --------------- | --------------------------------------------------------------------------------------------------------------------- |
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
## Tok2Vec.from_disk {#from_disk tag="method"}
Load the pipe from disk. Modifies the object in place and returns it.
> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> tok2vec.from_disk("/path/to/tok2vec")
> ```
| Name | Type | Description |
| ----------- | --------------- | -------------------------------------------------------------------------- |
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `Tok2Vec` | The modified `Tok2Vec` object. |
## Tok2Vec.to_bytes {#to_bytes tag="method"}
> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> tok2vec_bytes = tok2vec.to_bytes()
> ```
Serialize the pipe to a bytestring.
| Name | Type | Description |
| ----------- | --------------- | ------------------------------------------------------------------------- |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | bytes | The serialized form of the `Tok2Vec` object. |
## Tok2Vec.from_bytes {#from_bytes tag="method"}
Load the pipe from a bytestring. Modifies the object in place and returns it.
> #### Example
>
> ```python
> tok2vec_bytes = tok2vec.to_bytes()
> tok2vec = nlp.add_pipe("tok2vec")
> tok2vec.from_bytes(tok2vec_bytes)
> ```
| Name | Type | Description |
| ------------ | --------------- | ------------------------------------------------------------------------- |
| `bytes_data` | bytes | The data to load from. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `Tok2Vec` | The `Tok2Vec` object. |
## Serialization fields {#serialization-fields}
During serialization, spaCy will export several data fields used to restore
different aspects of the object. If needed, you can exclude them from
serialization by passing in the string names via the `exclude` argument.
> #### Example
>
> ```python
> data = tok2vec.to_disk("/path", exclude=["vocab"])
> ```
| Name | Description |
| ------- | -------------------------------------------------------------- |
| `vocab` | The shared [`Vocab`](/api/vocab). |
| `cfg` | The config file. You usually don't want to exclude this. |
| `model` | The binary model data. You usually don't want to exclude this. |