Mirror of https://github.com/explosion/spaCy.git (synced 2024-12-25 17:36:30 +03:00)

Commit ae4d8a6ffd: Update docstrings, docs and pipe consistency
Parent commit: 0094cb0d04
@ -1100,13 +1100,12 @@ class Language:
|
||||||
return scorer.score(examples)
|
return scorer.score(examples)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def use_params(self, params: dict, **cfg):
|
def use_params(self, params: dict):
|
||||||
"""Replace weights of models in the pipeline with those provided in the
|
"""Replace weights of models in the pipeline with those provided in the
|
||||||
params dictionary. Can be used as a contextmanager, in which case,
|
params dictionary. Can be used as a contextmanager, in which case,
|
||||||
models go back to their original weights after the block.
|
models go back to their original weights after the block.
|
||||||
|
|
||||||
params (dict): A dictionary of parameters keyed by model ID.
|
params (dict): A dictionary of parameters keyed by model ID.
|
||||||
**cfg: Config parameters.
|
|
||||||
|
|
||||||
EXAMPLE:
|
EXAMPLE:
|
||||||
>>> with nlp.use_params(optimizer.averages):
|
>>> with nlp.use_params(optimizer.averages):
|
||||||
|
|
|
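The `use_params` context manager documented in the hunk above is typically used to save a model with the optimizer's averaged weights. A minimal sketch, assuming a trained `nlp` pipeline and a thinc `optimizer` returned by `nlp.begin_training()`, as in the docstring's own example:

```python
# Illustrative sketch only: temporarily swap in the optimizer's averaged
# weights while writing the model to disk.
with nlp.use_params(optimizer.averages):
    nlp.to_disk("/tmp/model-best")
# Outside the block, the pipeline's original weights are restored.
```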
@ -128,7 +128,8 @@ class EntityLinker(Pipe):
|
||||||
|
|
||||||
def begin_training(
|
def begin_training(
|
||||||
self,
|
self,
|
||||||
get_examples: Callable = lambda: [],
|
get_examples: Callable[[], Iterable[Example]] = lambda: [],
|
||||||
|
*,
|
||||||
pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None,
|
pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None,
|
||||||
sgd: Optional[Optimizer] = None,
|
sgd: Optional[Optimizer] = None,
|
||||||
) -> Optimizer:
|
) -> Optimizer:
|
||||||
|
@ -273,7 +274,7 @@ class EntityLinker(Pipe):
|
||||||
|
|
||||||
stream (Iterable[Doc]): A stream of documents.
|
stream (Iterable[Doc]): A stream of documents.
|
||||||
batch_size (int): The number of documents to buffer.
|
batch_size (int): The number of documents to buffer.
|
||||||
YIELDS (Doc): PRocessed documents in order.
|
YIELDS (Doc): Processed documents in order.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/entitylinker#pipe
|
DOCS: https://spacy.io/api/entitylinker#pipe
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -97,7 +97,7 @@ class Morphologizer(Tagger):
|
||||||
"""Add a new label to the pipe.
|
"""Add a new label to the pipe.
|
||||||
|
|
||||||
label (str): The label to add.
|
label (str): The label to add.
|
||||||
RETURNS (int): 1
|
RETURNS (int): 0 if label is already present, otherwise 1.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/morphologizer#add_label
|
DOCS: https://spacy.io/api/morphologizer#add_label
|
||||||
"""
|
"""
|
||||||
|
|
|
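The return-value contract introduced above for `add_label` (0 if the label is already present, otherwise 1) can be checked directly. A short sketch, assuming a pipeline with a "textcat" component and made-up label names:

```python
# Illustrative sketch of the new add_label() return contract.
textcat = nlp.add_pipe("textcat")
assert textcat.add_label("POSITIVE") == 1  # newly added
assert textcat.add_label("POSITIVE") == 0  # already present, nothing added
```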
@ -8,41 +8,51 @@ from ..errors import Errors
|
||||||
from .. import util
|
from .. import util
|
||||||
|
|
||||||
|
|
||||||
def deserialize_config(path):
|
|
||||||
if path.exists():
|
|
||||||
return srsly.read_json(path)
|
|
||||||
else:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
class Pipe:
|
class Pipe:
|
||||||
"""This class is not instantiated directly. Components inherit from it, and
|
"""This class is a base class and not instantiated directly. Trainable
|
||||||
it defines the interface that components should follow to function as
|
pipeline components like the EntityRecognizer or TextCategorizer inherit
|
||||||
components in a spaCy analysis pipeline.
|
from it and it defines the interface that components should follow to
|
||||||
|
function as trainable components in a spaCy pipeline.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/pipe
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = None
|
name = None
|
||||||
|
|
||||||
def __init__(self, vocab, model, name, **cfg):
|
def __init__(self, vocab, model, name, **cfg):
|
||||||
"""Create a new pipe instance."""
|
"""Initialize a pipeline component.
|
||||||
|
|
||||||
|
vocab (Vocab): The shared vocabulary.
|
||||||
|
model (thinc.api.Model): The Thinc Model powering the pipeline component.
|
||||||
|
name (str): The component instance name.
|
||||||
|
**cfg: Additional settings and config parameters.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/pipe#init
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __call__(self, Doc doc):
|
def __call__(self, Doc doc):
|
||||||
"""Apply the pipe to one document. The document is
|
"""Add context-sensitive embeddings to the Doc.tensor attribute.
|
||||||
modified in-place, and returned.
|
|
||||||
|
|
||||||
Both __call__ and pipe should delegate to the `predict()`
|
doc (Doc): The Doc to process.
|
||||||
and `set_annotations()` methods.
|
RETURNS (Doc): The processed Doc.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/pipe#call
|
||||||
"""
|
"""
|
||||||
scores = self.predict([doc])
|
scores = self.predict([doc])
|
||||||
self.set_annotations([doc], scores)
|
self.set_annotations([doc], scores)
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
def pipe(self, stream, batch_size=128):
|
def pipe(self, stream, *, batch_size=128):
|
||||||
"""Apply the pipe to a stream of documents.
|
"""Apply the pipe to a stream of documents. This usually happens under
|
||||||
|
the hood when the nlp object is called on a text and all components are
|
||||||
|
applied to the Doc.
|
||||||
|
|
||||||
Both __call__ and pipe should delegate to the `predict()`
|
stream (Iterable[Doc]): A stream of documents.
|
||||||
and `set_annotations()` methods.
|
batch_size (int): The number of documents to buffer.
|
||||||
|
YIELDS (Doc): Processed documents in order.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/pipe#pipe
|
||||||
"""
|
"""
|
||||||
for docs in util.minibatch(stream, size=batch_size):
|
for docs in util.minibatch(stream, size=batch_size):
|
||||||
scores = self.predict(docs)
|
scores = self.predict(docs)
|
||||||
|
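The rewritten docstrings above describe how `__call__` and `pipe` delegate to `predict` and `set_annotations`. A hedged sketch of a custom component following that interface; the class and attribute names here are illustrative and not part of the commit:

```python
# Hypothetical component built on the Pipe base class described above.
# __call__ and pipe are inherited from Pipe and delegate to predict()
# and set_annotations().
from spacy.pipeline import Pipe


class CustomPipe(Pipe):
    def __init__(self, vocab, model, name="custom_pipe", **cfg):
        self.vocab = vocab
        self.model = model
        self.name = name
        self.cfg = dict(cfg)

    def predict(self, docs):
        # Run the statistical model without modifying the docs.
        return self.model.predict(docs)

    def set_annotations(self, docs, scores):
        # Write the pre-computed scores back onto the Doc objects.
        for doc, doc_scores in zip(docs, scores):
            doc.user_data[self.name] = doc_scores
```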
@ -50,38 +60,90 @@ class Pipe:
|
||||||
yield from docs
|
yield from docs
|
||||||
|
|
||||||
def predict(self, docs):
|
def predict(self, docs):
|
||||||
"""Apply the pipeline's model to a batch of docs, without
|
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||||
modifying them.
|
Returns a single tensor for a batch of documents.
|
||||||
|
|
||||||
|
docs (Iterable[Doc]): The documents to predict.
|
||||||
|
RETURNS: Vector representations for each token in the documents.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/pipe#predict
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def set_annotations(self, docs, scores):
|
def set_annotations(self, docs, scores):
|
||||||
"""Modify a batch of documents, using pre-computed scores."""
|
"""Modify a batch of documents, using pre-computed scores.
|
||||||
|
|
||||||
|
docs (Iterable[Doc]): The documents to modify.
|
||||||
|
scores: The scores to set, produced by Pipe.predict.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/pipe#set_annotations
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def rehearse(self, examples, sgd=None, losses=None, **config):
|
def rehearse(self, examples, *, sgd=None, losses=None, **config):
|
||||||
|
"""Perform a "rehearsal" update from a batch of data. Rehearsal updates
|
||||||
|
teach the current model to make predictions similar to an initial model,
|
||||||
|
to try to address the "catastrophic forgetting" problem. This feature is
|
||||||
|
experimental.
|
||||||
|
|
||||||
|
examples (Iterable[Example]): A batch of Example objects.
|
||||||
|
drop (float): The dropout rate.
|
||||||
|
sgd (thinc.api.Optimizer): The optimizer.
|
||||||
|
losses (Dict[str, float]): Optional record of the loss during training.
|
||||||
|
Updated using the component name as the key.
|
||||||
|
RETURNS (Dict[str, float]): The updated losses dictionary.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/pipe#rehearse
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_loss(self, examples, scores):
|
def get_loss(self, examples, scores):
|
||||||
"""Find the loss and gradient of loss for the batch of
|
"""Find the loss and gradient of loss for the batch of documents and
|
||||||
examples (with embedded docs) and their predicted scores."""
|
their predicted scores.
|
||||||
|
|
||||||
|
examples (Iterable[Example]): The batch of examples.
|
||||||
|
scores: Scores representing the model's predictions.
|
||||||
|
RETURNS (Tuple[float, float]): The loss and the gradient.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/pipe#get_loss
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def add_label(self, label):
|
def add_label(self, label):
|
||||||
"""Add an output label, to be predicted by the model.
|
"""Add an output label, to be predicted by the model. It's possible to
|
||||||
|
extend pretrained models with new labels, but care should be taken to
|
||||||
|
avoid the "catastrophic forgetting" problem.
|
||||||
|
|
||||||
It's possible to extend pretrained models with new labels,
|
label (str): The label to add.
|
||||||
but care should be taken to avoid the "catastrophic forgetting"
|
RETURNS (int): 0 if label is already present, otherwise 1.
|
||||||
problem.
|
|
||||||
|
DOCS: https://spacy.io/api/pipe#add_label
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
|
"""Create an optimizer for the pipeline component.
|
||||||
|
|
||||||
|
RETURNS (thinc.api.Optimizer): The optimizer.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/pipe#create_optimizer
|
||||||
|
"""
|
||||||
return create_default_optimizer()
|
return create_default_optimizer()
|
||||||
|
|
||||||
def begin_training(self, get_examples=lambda: [], *, pipeline=None, sgd=None):
|
def begin_training(self, get_examples=lambda: [], *, pipeline=None, sgd=None):
|
||||||
"""Initialize the pipe for training, using data exampes if available.
|
"""Initialize the pipe for training, using data examples if available.
|
||||||
If no model has been initialized yet, the model is added."""
|
|
||||||
|
get_examples (Callable[[], Iterable[Example]]): Optional function that
|
||||||
|
returns gold-standard Example objects.
|
||||||
|
pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
|
||||||
|
components that this component is part of. Corresponds to
|
||||||
|
nlp.pipeline.
|
||||||
|
sgd (thinc.api.Optimizer): Optional optimizer. Will be created with
|
||||||
|
create_optimizer if it doesn't exist.
|
||||||
|
RETURNS (thinc.api.Optimizer): The optimizer.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/pipe#begin_training
|
||||||
|
"""
|
||||||
self.model.initialize()
|
self.model.initialize()
|
||||||
if hasattr(self, "vocab"):
|
if hasattr(self, "vocab"):
|
||||||
link_vectors_to_models(self.vocab)
|
link_vectors_to_models(self.vocab)
|
||||||
|
@ -90,6 +152,7 @@ class Pipe:
|
||||||
return sgd
|
return sgd
|
||||||
|
|
||||||
def set_output(self, nO):
|
def set_output(self, nO):
|
||||||
|
# TODO: document this across components?
|
||||||
if self.model.has_dim("nO") is not False:
|
if self.model.has_dim("nO") is not False:
|
||||||
self.model.set_dim("nO", nO)
|
self.model.set_dim("nO", nO)
|
||||||
if self.model.has_ref("output_layer"):
|
if self.model.has_ref("output_layer"):
|
||||||
|
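The `set_output` helper in the hunk above (still flagged with a TODO about documentation) resizes the model's output dimension. A hedged sketch of how it would be called, assuming a concrete component such as a tagger or text categorizer that exposes a `labels` property and whose thinc model has an `nO` dimension:

```python
# Hypothetical usage of set_output(): grow the model's output dimension
# after adding a label. `pipe.labels` is assumed to exist on the concrete
# component; it is not defined on the Pipe base class itself.
pipe.add_label("NEW_LABEL")
pipe.set_output(len(pipe.labels))
```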
@ -99,6 +162,7 @@ class Pipe:
|
||||||
"""Get non-zero gradients of the model's parameters, as a dictionary
|
"""Get non-zero gradients of the model's parameters, as a dictionary
|
||||||
keyed by the parameter ID. The values are (weights, gradients) tuples.
|
keyed by the parameter ID. The values are (weights, gradients) tuples.
|
||||||
"""
|
"""
|
||||||
|
# TODO: How is this used?
|
||||||
gradients = {}
|
gradients = {}
|
||||||
queue = [self.model]
|
queue = [self.model]
|
||||||
seen = set()
|
seen = set()
|
||||||
|
@ -113,18 +177,33 @@ class Pipe:
|
||||||
return gradients
|
return gradients
|
||||||
|
|
||||||
def use_params(self, params):
|
def use_params(self, params):
|
||||||
"""Modify the pipe's model, to use the given parameter values."""
|
"""Modify the pipe's model, to use the given parameter values. At the
|
||||||
|
end of the context, the original parameters are restored.
|
||||||
|
|
||||||
|
params (dict): The parameter values to use in the model.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/pipe#use_params
|
||||||
|
"""
|
||||||
with self.model.use_params(params):
|
with self.model.use_params(params):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
def score(self, examples, **kwargs):
|
def score(self, examples, **kwargs):
|
||||||
|
"""Score a batch of examples.
|
||||||
|
|
||||||
|
examples (Iterable[Example]): The examples to score.
|
||||||
|
RETURNS (Dict[str, Any]): The scores.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/pipe#score
|
||||||
|
"""
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def to_bytes(self, exclude=tuple()):
|
def to_bytes(self, exclude=tuple()):
|
||||||
"""Serialize the pipe to a bytestring.
|
"""Serialize the pipe to a bytestring.
|
||||||
|
|
||||||
exclude (list): String names of serialization fields to exclude.
|
exclude (Iterable[str]): String names of serialization fields to exclude.
|
||||||
RETURNS (bytes): The serialized object.
|
RETURNS (bytes): The serialized object.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/pipe#to_bytes
|
||||||
"""
|
"""
|
||||||
serialize = {}
|
serialize = {}
|
||||||
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
|
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
|
||||||
|
@ -134,7 +213,13 @@ class Pipe:
|
||||||
return util.to_bytes(serialize, exclude)
|
return util.to_bytes(serialize, exclude)
|
||||||
|
|
||||||
def from_bytes(self, bytes_data, exclude=tuple()):
|
def from_bytes(self, bytes_data, exclude=tuple()):
|
||||||
"""Load the pipe from a bytestring."""
|
"""Load the pipe from a bytestring.
|
||||||
|
|
||||||
|
exclude (Iterable[str]): String names of serialization fields to exclude.
|
||||||
|
RETURNS (Pipe): The loaded object.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/pipe#from_bytes
|
||||||
|
"""
|
||||||
|
|
||||||
def load_model(b):
|
def load_model(b):
|
||||||
try:
|
try:
|
||||||
|
@ -151,7 +236,13 @@ class Pipe:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def to_disk(self, path, exclude=tuple()):
|
def to_disk(self, path, exclude=tuple()):
|
||||||
"""Serialize the pipe to disk."""
|
"""Serialize the pipe to disk.
|
||||||
|
|
||||||
|
path (str / Path): Path to a directory.
|
||||||
|
exclude (Iterable[str]): String names of serialization fields to exclude.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/pipe#to_disk
|
||||||
|
"""
|
||||||
serialize = {}
|
serialize = {}
|
||||||
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
|
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
|
||||||
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
|
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
|
||||||
|
@ -159,7 +250,14 @@ class Pipe:
|
||||||
util.to_disk(path, serialize, exclude)
|
util.to_disk(path, serialize, exclude)
|
||||||
|
|
||||||
def from_disk(self, path, exclude=tuple()):
|
def from_disk(self, path, exclude=tuple()):
|
||||||
"""Load the pipe from disk."""
|
"""Load the pipe from disk.
|
||||||
|
|
||||||
|
path (str / Path): Path to a directory.
|
||||||
|
exclude (Iterable[str]): String names of serialization fields to exclude.
|
||||||
|
RETURNS (Pipe): The loaded object.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/pipe#from_disk
|
||||||
|
"""
|
||||||
|
|
||||||
def load_model(p):
|
def load_model(p):
|
||||||
try:
|
try:
|
||||||
|
@ -173,3 +271,10 @@ class Pipe:
|
||||||
deserialize["model"] = load_model
|
deserialize["model"] = load_model
|
||||||
util.from_disk(path, deserialize, exclude)
|
util.from_disk(path, deserialize, exclude)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_config(path):
|
||||||
|
if path.exists():
|
||||||
|
return srsly.read_json(path)
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
|
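The serialization docstrings filled in above (`to_bytes`/`from_bytes` and `to_disk`/`from_disk`) all follow the same pattern. A short sketch of a byte-level round trip, assuming `pipe` is an existing component instance:

```python
# Sketch: round-trip a component through its byte representation.
# "vocab" is a shared serialization field and is commonly excluded.
data = pipe.to_bytes(exclude=["vocab"])
pipe = pipe.from_bytes(data, exclude=["vocab"])  # loads in place, returns self
```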
@ -329,7 +329,7 @@ class Tagger(Pipe):
|
||||||
label (str): The label to add.
|
label (str): The label to add.
|
||||||
values (Dict[int, str]): Optional values to map to the label, e.g. a
|
values (Dict[int, str]): Optional values to map to the label, e.g. a
|
||||||
tag map dictionary.
|
tag map dictionary.
|
||||||
RETURNS (int): 1
|
RETURNS (int): 0 if label is already present, otherwise 1.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/tagger#add_label
|
DOCS: https://spacy.io/api/tagger#add_label
|
||||||
"""
|
"""
|
||||||
|
@ -355,10 +355,6 @@ class Tagger(Pipe):
|
||||||
self.vocab.morphology.load_tag_map(tag_map)
|
self.vocab.morphology.load_tag_map(tag_map)
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
def use_params(self, params):
|
|
||||||
with self.model.use_params(params):
|
|
||||||
yield
|
|
||||||
|
|
||||||
def score(self, examples, **kwargs):
|
def score(self, examples, **kwargs):
|
||||||
"""Score a batch of examples.
|
"""Score a batch of examples.
|
||||||
|
|
||||||
|
|
|
@ -56,7 +56,17 @@ dropout = null
|
||||||
"textcat",
|
"textcat",
|
||||||
assigns=["doc.cats"],
|
assigns=["doc.cats"],
|
||||||
default_config={"labels": [], "model": DEFAULT_TEXTCAT_MODEL},
|
default_config={"labels": [], "model": DEFAULT_TEXTCAT_MODEL},
|
||||||
scores=["cats_score", "cats_score_desc", "cats_p", "cats_r", "cats_f", "cats_macro_f", "cats_macro_auc", "cats_f_per_type", "cats_macro_auc_per_type"],
|
scores=[
|
||||||
|
"cats_score",
|
||||||
|
"cats_score_desc",
|
||||||
|
"cats_p",
|
||||||
|
"cats_r",
|
||||||
|
"cats_f",
|
||||||
|
"cats_macro_f",
|
||||||
|
"cats_macro_auc",
|
||||||
|
"cats_f_per_type",
|
||||||
|
"cats_macro_auc_per_type",
|
||||||
|
],
|
||||||
default_score_weights={"cats_score": 1.0},
|
default_score_weights={"cats_score": 1.0},
|
||||||
)
|
)
|
||||||
def make_textcat(
|
def make_textcat(
|
||||||
|
@ -120,7 +130,7 @@ class TextCategorizer(Pipe):
|
||||||
|
|
||||||
stream (Iterable[Doc]): A stream of documents.
|
stream (Iterable[Doc]): A stream of documents.
|
||||||
batch_size (int): The number of documents to buffer.
|
batch_size (int): The number of documents to buffer.
|
||||||
YIELDS (Doc): PRocessed documents in order.
|
YIELDS (Doc): Processed documents in order.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/textcategorizer#pipe
|
DOCS: https://spacy.io/api/textcategorizer#pipe
|
||||||
"""
|
"""
|
||||||
|
@ -288,7 +298,7 @@ class TextCategorizer(Pipe):
|
||||||
"""Add a new label to the pipe.
|
"""Add a new label to the pipe.
|
||||||
|
|
||||||
label (str): The label to add.
|
label (str): The label to add.
|
||||||
RETURNS (int): 1.
|
RETURNS (int): 0 if label is already present, otherwise 1.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/textcategorizer#add_label
|
DOCS: https://spacy.io/api/textcategorizer#add_label
|
||||||
"""
|
"""
|
||||||
|
|
|
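The reformatted `scores` list above is part of the `@Language.factory` registration for the textcat component. A hedged sketch of the same registration pattern for a hypothetical component, using a subset of the decorator arguments that appear in this hunk; all component, config and score names are made up:

```python
# Hypothetical factory registration mirroring the textcat decorator above.
from spacy.language import Language


class MyComponent:
    def __init__(self, nlp, name, threshold):
        self.name = name
        self.threshold = threshold

    def __call__(self, doc):
        # A no-op component body, just to keep the sketch self-contained.
        return doc


@Language.factory(
    "my_component",
    assigns=["doc._.my_attr"],
    default_config={"threshold": 0.5},
    default_score_weights={"my_score": 1.0},
)
def make_my_component(nlp: Language, name: str, threshold: float):
    return MyComponent(nlp, name, threshold)
```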
@ -34,10 +34,13 @@ def make_tok2vec(nlp: Language, name: str, model: Model) -> "Tok2Vec":
|
||||||
|
|
||||||
class Tok2Vec(Pipe):
|
class Tok2Vec(Pipe):
|
||||||
def __init__(self, vocab: Vocab, model: Model, name: str = "tok2vec") -> None:
|
def __init__(self, vocab: Vocab, model: Model, name: str = "tok2vec") -> None:
|
||||||
"""Construct a new statistical model. Weights are not allocated on
|
"""Initialize a tok2vec component.
|
||||||
initialisation.
|
|
||||||
vocab (Vocab): A `Vocab` instance. The model must share the same `Vocab`
|
vocab (Vocab): The shared vocabulary.
|
||||||
instance with the `Doc` objects it will process.
|
model (thinc.api.Model): The Thinc Model powering the pipeline component.
|
||||||
|
name (str): The component instance name.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/tok2vec#init
|
||||||
"""
|
"""
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.model = model
|
self.model = model
|
||||||
|
@ -57,20 +60,27 @@ class Tok2Vec(Pipe):
|
||||||
self.add_listener(node)
|
self.add_listener(node)
|
||||||
|
|
||||||
def __call__(self, doc: Doc) -> Doc:
|
def __call__(self, doc: Doc) -> Doc:
|
||||||
"""Add context-sensitive vectors to a `Doc`, e.g. from a CNN or LSTM
|
"""Add context-sensitive embeddings to the Doc.tensor attribute.
|
||||||
model. Vectors are set to the `Doc.tensor` attribute.
|
|
||||||
docs (Doc or iterable): One or more documents to add vectors to.
|
doc (Doc): The Doc to process.
|
||||||
RETURNS (dict or None): Intermediate computations.
|
RETURNS (Doc): The processed Doc.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/tok2vec#call
|
||||||
"""
|
"""
|
||||||
tokvecses = self.predict([doc])
|
tokvecses = self.predict([doc])
|
||||||
self.set_annotations([doc], tokvecses)
|
self.set_annotations([doc], tokvecses)
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
def pipe(self, stream: Iterator[Doc], batch_size: int = 128) -> Iterator[Doc]:
|
def pipe(self, stream: Iterator[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
|
||||||
"""Process `Doc` objects as a stream.
|
"""Apply the pipe to a stream of documents. This usually happens under
|
||||||
stream (iterator): A sequence of `Doc` objects to process.
|
the hood when the nlp object is called on a text and all components are
|
||||||
batch_size (int): Number of `Doc` objects to group.
|
applied to the Doc.
|
||||||
YIELDS (iterator): A sequence of `Doc` objects, in order of input.
|
|
||||||
|
stream (Iterable[Doc]): A stream of documents.
|
||||||
|
batch_size (int): The number of documents to buffer.
|
||||||
|
YIELDS (Doc): Processed documents in order.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/tok2vec#pipe
|
||||||
"""
|
"""
|
||||||
for docs in minibatch(stream, batch_size):
|
for docs in minibatch(stream, batch_size):
|
||||||
docs = list(docs)
|
docs = list(docs)
|
||||||
|
@ -78,10 +88,14 @@ class Tok2Vec(Pipe):
|
||||||
self.set_annotations(docs, tokvecses)
|
self.set_annotations(docs, tokvecses)
|
||||||
yield from docs
|
yield from docs
|
||||||
|
|
||||||
def predict(self, docs: Sequence[Doc]):
|
def predict(self, docs: Iterable[Doc]):
|
||||||
"""Return a single tensor for a batch of documents.
|
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||||
docs (iterable): A sequence of `Doc` objects.
|
Returns a single tensor for a batch of documents.
|
||||||
RETURNS (object): Vector representations for each token in the documents.
|
|
||||||
|
docs (Iterable[Doc]): The documents to predict.
|
||||||
|
RETURNS: Vector representations for each token in the documents.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/tok2vec#predict
|
||||||
"""
|
"""
|
||||||
tokvecs = self.model.predict(docs)
|
tokvecs = self.model.predict(docs)
|
||||||
batch_id = Tok2VecListener.get_batch_id(docs)
|
batch_id = Tok2VecListener.get_batch_id(docs)
|
||||||
|
@ -90,9 +104,12 @@ class Tok2Vec(Pipe):
|
||||||
return tokvecs
|
return tokvecs
|
||||||
|
|
||||||
def set_annotations(self, docs: Sequence[Doc], tokvecses) -> None:
|
def set_annotations(self, docs: Sequence[Doc], tokvecses) -> None:
|
||||||
"""Set the tensor attribute for a batch of documents.
|
"""Modify a batch of documents, using pre-computed scores.
|
||||||
docs (iterable): A sequence of `Doc` objects.
|
|
||||||
tokvecs (object): Vector representation for each token in the documents.
|
docs (Iterable[Doc]): The documents to modify.
|
||||||
|
tokvecses: The tensors to set, produced by Tok2Vec.predict.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/tok2vec#set_annotations
|
||||||
"""
|
"""
|
||||||
for doc, tokvecs in zip(docs, tokvecses):
|
for doc, tokvecs in zip(docs, tokvecses):
|
||||||
assert tokvecs.shape[0] == len(doc)
|
assert tokvecs.shape[0] == len(doc)
|
||||||
|
@ -107,13 +124,19 @@ class Tok2Vec(Pipe):
|
||||||
losses: Optional[Dict[str, float]] = None,
|
losses: Optional[Dict[str, float]] = None,
|
||||||
set_annotations: bool = False,
|
set_annotations: bool = False,
|
||||||
):
|
):
|
||||||
"""Update the model.
|
"""Learn from a batch of documents and gold-standard information,
|
||||||
examples (Iterable[Example]): A batch of examples
|
updating the pipe's model.
|
||||||
drop (float): The droput rate.
|
|
||||||
sgd (Optimizer): An optimizer.
|
examples (Iterable[Example]): A batch of Example objects.
|
||||||
losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
|
drop (float): The dropout rate.
|
||||||
set_annotations (bool): whether or not to update the examples with the predictions
|
set_annotations (bool): Whether or not to update the Example objects
|
||||||
RETURNS (Dict[str, float]): The updated losses dictionary
|
with the predictions.
|
||||||
|
sgd (thinc.api.Optimizer): The optimizer.
|
||||||
|
losses (Dict[str, float]): Optional record of the loss during training.
|
||||||
|
Updated using the component name as the key.
|
||||||
|
RETURNS (Dict[str, float]): The updated losses dictionary.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/tok2vec#update
|
||||||
"""
|
"""
|
||||||
if losses is None:
|
if losses is None:
|
||||||
losses = {}
|
losses = {}
|
||||||
|
@ -122,7 +145,6 @@ class Tok2Vec(Pipe):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
tokvecs, bp_tokvecs = self.model.begin_update(docs)
|
tokvecs, bp_tokvecs = self.model.begin_update(docs)
|
||||||
|
|
||||||
d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
|
d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
|
||||||
losses.setdefault(self.name, 0.0)
|
losses.setdefault(self.name, 0.0)
|
||||||
|
|
||||||
|
@ -156,14 +178,23 @@ class Tok2Vec(Pipe):
|
||||||
|
|
||||||
def begin_training(
|
def begin_training(
|
||||||
self,
|
self,
|
||||||
get_examples: Callable = lambda: [],
|
get_examples: Callable[[], Iterable[Example]] = lambda: [],
|
||||||
|
*,
|
||||||
pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None,
|
pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None,
|
||||||
sgd: Optional[Optimizer] = None,
|
sgd: Optional[Optimizer] = None,
|
||||||
):
|
):
|
||||||
"""Allocate models and pre-process training data
|
"""Initialize the pipe for training, using data examples if available.
|
||||||
|
|
||||||
get_examples (function): Function returning example training data.
|
get_examples (Callable[[], Iterable[Example]]): Optional function that
|
||||||
pipeline (list): The pipeline the model is part of.
|
returns gold-standard Example objects.
|
||||||
|
pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
|
||||||
|
components that this component is part of. Corresponds to
|
||||||
|
nlp.pipeline.
|
||||||
|
sgd (thinc.api.Optimizer): Optional optimizer. Will be created with
|
||||||
|
create_optimizer if it doesn't exist.
|
||||||
|
RETURNS (thinc.api.Optimizer): The optimizer.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/tok2vec#begin_training
|
||||||
"""
|
"""
|
||||||
docs = [Doc(Vocab(), words=["hello"])]
|
docs = [Doc(Vocab(), words=["hello"])]
|
||||||
self.model.initialize(X=docs)
|
self.model.initialize(X=docs)
|
||||||
|
|
|
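The `begin_training` and `update` docstrings rewritten in the Tok2Vec hunks above describe the standard training contract. A hedged sketch of driving it, following the example pattern used on the docs pages in this commit; it assumes `nlp` has a "tok2vec" component and `train_examples` is a list of Example objects, with illustrative epoch count and dropout:

```python
# Sketch of the update() contract described in the docstrings above.
tok2vec = nlp.get_pipe("tok2vec")
optimizer = nlp.begin_training()
losses = {}
for epoch in range(10):
    tok2vec.update(train_examples, drop=0.2, sgd=optimizer, losses=losses)
    print(epoch, losses["tok2vec"])  # loss is recorded under the component name
```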
@ -123,6 +123,8 @@ cdef class Parser:
|
||||||
resized = True
|
resized = True
|
||||||
if resized:
|
if resized:
|
||||||
self._resize()
|
self._resize()
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
def _resize(self):
|
def _resize(self):
|
||||||
self.model.attrs["resize_output"](self.model, self.moves.n_moves)
|
self.model.attrs["resize_output"](self.model, self.moves.n_moves)
|
||||||
|
|
|
@ -1182,6 +1182,7 @@ VECTORS_KEY = "spacy_pretrained_vectors"
|
||||||
|
|
||||||
|
|
||||||
def create_default_optimizer() -> Optimizer:
|
def create_default_optimizer() -> Optimizer:
|
||||||
|
# TODO: Do we still want to allow env_opt?
|
||||||
learn_rate = env_opt("learn_rate", 0.001)
|
learn_rate = env_opt("learn_rate", 0.001)
|
||||||
beta1 = env_opt("optimizer_B1", 0.9)
|
beta1 = env_opt("optimizer_B1", 0.9)
|
||||||
beta2 = env_opt("optimizer_B2", 0.999)
|
beta2 = env_opt("optimizer_B2", 0.999)
|
||||||
|
|
|
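The `create_default_optimizer` helper above builds an Adam optimizer from `env_opt` values. A sketch of roughly the same thing using thinc directly, with the default constants copied from the hunk and everything else assumed:

```python
# Sketch: what create_default_optimizer() amounts to with its defaults.
from thinc.api import Adam

optimizer = Adam(learn_rate=0.001, beta1=0.9, beta2=0.999)
```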
@ -248,19 +248,20 @@ component.
|
||||||
|
|
||||||
## DependencyParser.use_params {#use_params tag="method, contextmanager"}
|
## DependencyParser.use_params {#use_params tag="method, contextmanager"}
|
||||||
|
|
||||||
Modify the pipe's model, to use the given parameter values.
|
Modify the pipe's model, to use the given parameter values. At the end of the
|
||||||
|
context, the original parameters are restored.
|
||||||
|
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
> ```python
|
> ```python
|
||||||
> parser = DependencyParser(nlp.vocab)
|
> parser = DependencyParser(nlp.vocab)
|
||||||
> with parser.use_params():
|
> with parser.use_params(optimizer.averages):
|
||||||
> parser.to_disk("/best_model")
|
> parser.to_disk("/best_model")
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Type | Description |
|
| Name | Type | Description |
|
||||||
| -------- | ---- | ---------------------------------------------------------------------------------------------------------- |
|
| -------- | ---- | ----------------------------------------- |
|
||||||
| `params` | - | The parameter values to use in the model. At the end of the context, the original parameters are restored. |
|
| `params` | dict | The parameter values to use in the model. |
|
||||||
|
|
||||||
## DependencyParser.add_label {#add_label tag="method"}
|
## DependencyParser.add_label {#add_label tag="method"}
|
||||||
|
|
||||||
|
@ -273,9 +274,10 @@ Add a new label to the pipe.
|
||||||
> parser.add_label("MY_LABEL")
|
> parser.add_label("MY_LABEL")
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Type | Description |
|
| Name | Type | Description |
|
||||||
| ------- | ---- | ----------------- |
|
| ----------- | ---- | --------------------------------------------------- |
|
||||||
| `label` | str | The label to add. |
|
| `label` | str | The label to add. |
|
||||||
|
| **RETURNS** | int | `0` if the label is already present, otherwise `1`. |
|
||||||
|
|
||||||
## DependencyParser.to_disk {#to_disk tag="method"}
|
## DependencyParser.to_disk {#to_disk tag="method"}
|
||||||
|
|
||||||
|
|
|
@ -239,7 +239,8 @@ Create an optimizer for the pipeline component.
|
||||||
|
|
||||||
## EntityLinker.use_params {#use_params tag="method, contextmanager"}
|
## EntityLinker.use_params {#use_params tag="method, contextmanager"}
|
||||||
|
|
||||||
Modify the pipe's EL model, to use the given parameter values.
|
Modify the pipe's model, to use the given parameter values. At the end of the
|
||||||
|
context, the original parameters are restored.
|
||||||
|
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
|
@ -249,9 +250,9 @@ Modify the pipe's EL model, to use the given parameter values.
|
||||||
> entity_linker.to_disk("/best_model")
|
> entity_linker.to_disk("/best_model")
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Type | Description |
|
| Name | Type | Description |
|
||||||
| -------- | ---- | ---------------------------------------------------------------------------------------------------------- |
|
| -------- | ---- | ----------------------------------------- |
|
||||||
| `params` | dict | The parameter values to use in the model. At the end of the context, the original parameters are restored. |
|
| `params` | dict | The parameter values to use in the model. |
|
||||||
|
|
||||||
## EntityLinker.to_disk {#to_disk tag="method"}
|
## EntityLinker.to_disk {#to_disk tag="method"}
|
||||||
|
|
||||||
|
|
|
@ -247,7 +247,8 @@ Create an optimizer for the pipeline component.
|
||||||
|
|
||||||
## EntityRecognizer.use_params {#use_params tag="method, contextmanager"}
|
## EntityRecognizer.use_params {#use_params tag="method, contextmanager"}
|
||||||
|
|
||||||
Modify the pipe's model, to use the given parameter values.
|
Modify the pipe's model, to use the given parameter values. At the end of the
|
||||||
|
context, the original parameters are restored.
|
||||||
|
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
|
@ -257,9 +258,9 @@ Modify the pipe's model, to use the given parameter values.
|
||||||
> ner.to_disk("/best_model")
|
> ner.to_disk("/best_model")
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Type | Description |
|
| Name | Type | Description |
|
||||||
| -------- | ---- | ---------------------------------------------------------------------------------------------------------- |
|
| -------- | ---- | ----------------------------------------- |
|
||||||
| `params` | dict | The parameter values to use in the model. At the end of the context, the original parameters are restored. |
|
| `params` | dict | The parameter values to use in the model. |
|
||||||
|
|
||||||
## EntityRecognizer.add_label {#add_label tag="method"}
|
## EntityRecognizer.add_label {#add_label tag="method"}
|
||||||
|
|
||||||
|
@ -272,9 +273,10 @@ Add a new label to the pipe.
|
||||||
> ner.add_label("MY_LABEL")
|
> ner.add_label("MY_LABEL")
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Type | Description |
|
| Name | Type | Description |
|
||||||
| ------- | ---- | ----------------- |
|
| ----------- | ---- | --------------------------------------------------- |
|
||||||
| `label` | str | The label to add. |
|
| `label` | str | The label to add. |
|
||||||
|
| **RETURNS** | int | `0` if the label is already present, otherwise `1`. |
|
||||||
|
|
||||||
## EntityRecognizer.to_disk {#to_disk tag="method"}
|
## EntityRecognizer.to_disk {#to_disk tag="method"}
|
||||||
|
|
||||||
|
|
|
@ -271,7 +271,6 @@ their original weights after the block.
|
||||||
| Name | Type | Description |
|
| Name | Type | Description |
|
||||||
| -------- | ---- | --------------------------------------------- |
|
| -------- | ---- | --------------------------------------------- |
|
||||||
| `params` | dict | A dictionary of parameters keyed by model ID. |
|
| `params` | dict | A dictionary of parameters keyed by model ID. |
|
||||||
| `**cfg` | - | Config parameters. |
|
|
||||||
|
|
||||||
## Language.create_pipe {#create_pipe tag="method" new="2"}
|
## Language.create_pipe {#create_pipe tag="method" new="2"}
|
||||||
|
|
||||||
|
|
|
@ -233,19 +233,20 @@ Create an optimizer for the pipeline component.
|
||||||
|
|
||||||
## Morphologizer.use_params {#use_params tag="method, contextmanager"}
|
## Morphologizer.use_params {#use_params tag="method, contextmanager"}
|
||||||
|
|
||||||
Modify the pipe's model, to use the given parameter values.
|
Modify the pipe's model, to use the given parameter values. At the end of the
|
||||||
|
context, the original parameters are restored.
|
||||||
|
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
> ```python
|
> ```python
|
||||||
> morphologizer = nlp.add_pipe("morphologizer")
|
> morphologizer = nlp.add_pipe("morphologizer")
|
||||||
> with morphologizer.use_params():
|
> with morphologizer.use_params(optimizer.averages):
|
||||||
> morphologizer.to_disk("/best_model")
|
> morphologizer.to_disk("/best_model")
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Type | Description |
|
| Name | Type | Description |
|
||||||
| -------- | ---- | ---------------------------------------------------------------------------------------------------------- |
|
| -------- | ---- | ----------------------------------------- |
|
||||||
| `params` | - | The parameter values to use in the model. At the end of the context, the original parameters are restored. |
|
| `params` | dict | The parameter values to use in the model. |
|
||||||
|
|
||||||
## Morphologizer.add_label {#add_label tag="method"}
|
## Morphologizer.add_label {#add_label tag="method"}
|
||||||
|
|
||||||
|
@ -259,9 +260,10 @@ both `pos` and `morph`, the label should include the UPOS as the feature `POS`.
|
||||||
> morphologizer.add_label("Mood=Ind|POS=VERB|Tense=Past|VerbForm=Fin")
|
> morphologizer.add_label("Mood=Ind|POS=VERB|Tense=Past|VerbForm=Fin")
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Type | Description |
|
| Name | Type | Description |
|
||||||
| ------- | ---- | ----------------- |
|
| ----------- | ---- | --------------------------------------------------- |
|
||||||
| `label` | str | The label to add. |
|
| `label` | str | The label to add. |
|
||||||
|
| **RETURNS** | int | `0` if the label is already present, otherwise `1`. |
|
||||||
|
|
||||||
## Morphologizer.to_disk {#to_disk tag="method"}
|
## Morphologizer.to_disk {#to_disk tag="method"}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,381 @@
|
||||||
---
|
---
|
||||||
title: Pipe
|
title: Pipe
|
||||||
tag: class
|
tag: class
|
||||||
|
teaser: Base class for trainable pipeline components
|
||||||
---
|
---
|
||||||
|
|
||||||
TODO: write
|
This class is a base class and **not instantiated directly**. Trainable pipeline
|
||||||
|
components like the [`EntityRecognizer`](/api/entityrecognizer) or
|
||||||
|
[`TextCategorizer`](/api/textcategorizer) inherit from it and it defines the
|
||||||
|
interface that components should follow to function as trainable components in a
|
||||||
|
spaCy pipeline.
|
||||||
|
|
||||||
|
```python
|
||||||
|
https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/pipe.pyx
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pipe.\_\_init\_\_ {#init tag="method"}
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> from spacy.pipeline import Pipe
|
||||||
|
> from spacy.language import Language
|
||||||
|
>
|
||||||
|
> class CustomPipe(Pipe):
|
||||||
|
> ...
|
||||||
|
>
|
||||||
|
> @Language.factory("your_custom_pipe", default_config={"model": MODEL})
|
||||||
|
> def make_custom_pipe(nlp, name, model):
|
||||||
|
> return CustomPipe(nlp.vocab, model, name)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
Create a new pipeline instance. In your application, you would normally use a
|
||||||
|
shortcut for this and instantiate the component using its string name and
|
||||||
|
[`nlp.add_pipe`](/api/language#create_pipe).
|
||||||
|
|
||||||
|
<Infobox variant="danger">
|
||||||
|
|
||||||
|
This method needs to be overwritten with your own custom `__init__` method.
|
||||||
|
|
||||||
|
</Infobox>
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ------- | ------------------------------------------ | ------------------------------------------------------------------------------------------- |
|
||||||
|
| `vocab` | `Vocab` | The shared vocabulary. |
|
||||||
|
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
|
||||||
|
| `name` | str | String name of the component instance. Used to add entries to the `losses` during training. |
|
||||||
|
| `**cfg` | | Additional config parameters and settings. |
|
||||||
|
|
||||||
|
## Pipe.\_\_call\_\_ {#call tag="method"}
|
||||||
|
|
||||||
|
Apply the pipe to one document. The document is modified in place, and returned.
|
||||||
|
This usually happens under the hood when the `nlp` object is called on a text
|
||||||
|
and all pipeline components are applied to the `Doc` in order. Both
|
||||||
|
[`__call__`](/api/pipe#call) and [`pipe`](/api/pipe#pipe) delegate to the
|
||||||
|
[`predict`](/api/pipe#predict) and
|
||||||
|
[`set_annotations`](/api/pipe#set_annotations) methods.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> doc = nlp("This is a sentence.")
|
||||||
|
> pipe = nlp.add_pipe("your_custom_pipe")
|
||||||
|
> # This usually happens under the hood
|
||||||
|
> processed = pipe(doc)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ----------- | ----- | ------------------------ |
|
||||||
|
| `doc` | `Doc` | The document to process. |
|
||||||
|
| **RETURNS** | `Doc` | The processed document. |
|
||||||
|
|
||||||
|
## Pipe.pipe {#pipe tag="method"}
|
||||||
|
|
||||||
|
Apply the pipe to a stream of documents. This usually happens under the hood
|
||||||
|
when the `nlp` object is called on a text and all pipeline components are
|
||||||
|
applied to the `Doc` in order. Both [`__call__`](/api/pipe#call) and
|
||||||
|
[`pipe`](/api/pipe#pipe) delegate to the [`predict`](/api/pipe#predict) and
|
||||||
|
[`set_annotations`](/api/pipe#set_annotations) methods.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> pipe = nlp.add_pipe("your_custom_pipe")
|
||||||
|
> for doc in pipe.pipe(docs, batch_size=50):
|
||||||
|
> pass
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| -------------- | --------------- | ----------------------------------------------------- |
|
||||||
|
| `stream` | `Iterable[Doc]` | A stream of documents. |
|
||||||
|
| _keyword-only_ | | |
|
||||||
|
| `batch_size` | int | The number of documents to buffer. Defaults to `128`. |
|
||||||
|
| **YIELDS** | `Doc` | The processed documents in order. |
|
||||||
|
|
||||||
|
## Pipe.begin_training {#begin_training tag="method"}
|
||||||
|
|
||||||
|
Initialize the pipe for training, using data examples if available. Return an
|
||||||
|
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> pipe = nlp.add_pipe("your_custom_pipe")
|
||||||
|
> optimizer = pipe.begin_training(pipeline=nlp.pipeline)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| -------------- | --------------------------------------------------- | ---------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `get_examples` | `Callable[[], Iterable[Example]]` | Optional function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. |
|
||||||
|
| _keyword-only_ | | |
|
||||||
|
| `pipeline` | `List[Tuple[str, Callable]]` | Optional list of pipeline components that this component is part of. |
|
||||||
|
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | An optional optimizer. Will be created via [`create_optimizer`](/api/pipe#create_optimizer) if not set. |
|
||||||
|
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
|
||||||
|
|
||||||
|
## Pipe.predict {#predict tag="method"}
|
||||||
|
|
||||||
|
Apply the pipeline's model to a batch of docs, without modifying them.
|
||||||
|
|
||||||
|
<Infobox variant="danger">
|
||||||
|
|
||||||
|
This method needs to be overwritten with your own custom `predict` method.
|
||||||
|
|
||||||
|
</Infobox>
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> pipe = nlp.add_pipe("your_custom_pipe")
|
||||||
|
> scores = pipe.predict([doc1, doc2])
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ----------- | --------------- | ----------------------------------------- |
|
||||||
|
| `docs` | `Iterable[Doc]` | The documents to predict. |
|
||||||
|
| **RETURNS** | - | The model's prediction for each document. |
|
||||||
|
|
||||||
|
## Pipe.set_annotations {#set_annotations tag="method"}
|
||||||
|
|
||||||
|
Modify a batch of documents, using pre-computed scores.
|
||||||
|
|
||||||
|
<Infobox variant="danger">
|
||||||
|
|
||||||
|
This method needs to be overwritten with your own custom `set_annotations`
|
||||||
|
method.
|
||||||
|
|
||||||
|
</Infobox>
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> pipe = nlp.add_pipe("your_custom_pipe")
|
||||||
|
> scores = pipe.predict(docs)
|
||||||
|
> pipe.set_annotations(docs, scores)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| -------- | --------------- | ---------------------------------------------- |
|
||||||
|
| `docs` | `Iterable[Doc]` | The documents to modify. |
|
||||||
|
| `scores` | - | The scores to set, produced by `Pipe.predict`. |
|
||||||
|
|
||||||
|
## Pipe.update {#update tag="method"}
|
||||||
|
|
||||||
|
Learn from a batch of documents and gold-standard information, updating the
|
||||||
|
pipe's model. Delegates to [`predict`](/api/pipe#predict).
|
||||||
|
|
||||||
|
<Infobox variant="danger">
|
||||||
|
|
||||||
|
This method needs to be overwritten with your own custom `update` method.
|
||||||
|
|
||||||
|
</Infobox>
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> pipe = nlp.add_pipe("your_custom_pipe")
|
||||||
|
> optimizer = nlp.begin_training()
|
||||||
|
> losses = pipe.update(examples, sgd=optimizer)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ----------------- | --------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
|
||||||
|
| _keyword-only_ | | |
|
||||||
|
| `drop` | float | The dropout rate. |
|
||||||
|
| `set_annotations` | bool | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](/api/pipe#set_annotations). |
|
||||||
|
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
|
||||||
|
| `losses` | `Dict[str, float]` | Optional record of the loss during training. Updated using the component name as the key. |
|
||||||
|
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |
|
||||||
|
|
||||||
|
## Pipe.rehearse {#rehearse tag="method,experimental"}
|
||||||
|
|
||||||
|
Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
|
||||||
|
current model to make predictions similar to an initial model, to try to address
|
||||||
|
the "catastrophic forgetting" problem. This feature is experimental.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> pipe = nlp.add_pipe("your_custom_pipe")
|
||||||
|
> optimizer = nlp.begin_training()
|
||||||
|
> losses = pipe.rehearse(examples, sgd=optimizer)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| -------------- | --------------------------------------------------- | ----------------------------------------------------------------------------------------- |
|
||||||
|
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
|
||||||
|
| _keyword-only_ | | |
|
||||||
|
| `drop` | float | The dropout rate. |
|
||||||
|
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
|
||||||
|
| `losses` | `Dict[str, float]` | Optional record of the loss during training. Updated using the component name as the key. |
|
||||||
|
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |
|
||||||
|
|
||||||
|
## Pipe.get_loss {#get_loss tag="method"}
|
||||||
|
|
||||||
|
Find the loss and gradient of loss for the batch of documents and their
|
||||||
|
predicted scores.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> ner = nlp.add_pipe("ner")
|
||||||
|
> scores = ner.predict([eg.predicted for eg in examples])
|
||||||
|
> loss, d_loss = ner.get_loss(examples, scores)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ----------- | --------------------- | --------------------------------------------------- |
|
||||||
|
| `examples` | `Iterable[Example]` | The batch of examples. |
|
||||||
|
| `scores` | | Scores representing the model's predictions. |
|
||||||
|
| **RETURNS** | `Tuple[float, float]` | The loss and the gradient, i.e. `(loss, gradient)`. |
|
||||||
|
|
||||||
|
## Pipe.score {#score tag="method" new="3"}
|
||||||
|
|
||||||
|
Score a batch of examples.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> scores = pipe.score(examples)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ----------- | ------------------- | --------------------------------------------------------- |
|
||||||
|
| `examples` | `Iterable[Example]` | The examples to score. |
|
||||||
|
| **RETURNS** | `Dict[str, Any]` | The scores, e.g. produced by the [`Scorer`](/api/scorer). |
|
||||||
|
|
||||||
|
## Pipe.create_optimizer {#create_optimizer tag="method"}
|
||||||
|
|
||||||
|
Create an optimizer for the pipeline component. Defaults to
|
||||||
|
[`Adam`](https://thinc.ai/docs/api-optimizers#adam) with default settings.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> pipe = nlp.add_pipe("your_custom_pipe")
|
||||||
|
> optimizer = pipe.create_optimizer()
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ----------- | --------------------------------------------------- | -------------- |
|
||||||
|
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
|
||||||
|
|
||||||
|
## Pipe.add_label {#add_label tag="method"}
|
||||||
|
|
||||||
|
Add a new label to the pipe. It's possible to extend pretrained models with new
|
||||||
|
labels, but care should be taken to avoid the "catastrophic forgetting" problem.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> pipe = nlp.add_pipe("your_custom_pipe")
|
||||||
|
> pipe.add_label("MY_LABEL")
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ----------- | ---- | --------------------------------------------------- |
|
||||||
|
| `label` | str | The label to add. |
|
||||||
|
| **RETURNS** | int | `0` if the label is already present, otherwise `1`. |
|
||||||
|
|
||||||
|
## Pipe.use_params {#use_params tag="method, contextmanager"}
|
||||||
|
|
||||||
|
Modify the pipe's model, to use the given parameter values. At the end of the
|
||||||
|
context, the original parameters are restored.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> pipe = nlp.add_pipe("your_custom_pipe")
|
||||||
|
> with pipe.use_params(optimizer.averages):
|
||||||
|
> pipe.to_disk("/best_model")
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| -------- | ---- | ----------------------------------------- |
|
||||||
|
| `params` | dict | The parameter values to use in the model. |
|
||||||
|
|
||||||
|
## Pipe.to_disk {#to_disk tag="method"}
|
||||||
|
|
||||||
|
Serialize the pipe to disk.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> pipe = nlp.add_pipe("your_custom_pipe")
|
||||||
|
> pipe.to_disk("/path/to/pipe")
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| --------- | --------------- | --------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
|
||||||
|
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||||
|
|
||||||
|
## Pipe.from_disk {#from_disk tag="method"}
|
||||||
|
|
||||||
|
Load the pipe from disk. Modifies the object in place and returns it.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> pipe = nlp.add_pipe("your_custom_pipe")
|
||||||
|
> pipe.from_disk("/path/to/pipe")
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ----------- | --------------- | -------------------------------------------------------------------------- |
|
||||||
|
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
|
||||||
|
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||||
|
| **RETURNS** | `Pipe` | The modified pipe. |
|
||||||
|
|
||||||
|
## Pipe.to_bytes {#to_bytes tag="method"}
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> pipe = nlp.add_pipe("your_custom_pipe")
|
||||||
|
> pipe_bytes = pipe.to_bytes()
|
||||||
|
> ```
|
||||||
|
|
||||||
|
Serialize the pipe to a bytestring.
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ----------- | --------------- | ------------------------------------------------------------------------- |
|
||||||
|
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||||
|
| **RETURNS** | bytes | The serialized form of the pipe. |
|
||||||
|
|
||||||
|
## Pipe.from_bytes {#from_bytes tag="method"}
|
||||||
|
|
||||||
|
Load the pipe from a bytestring. Modifies the object in place and returns it.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> pipe_bytes = pipe.to_bytes()
|
||||||
|
> pipe = nlp.add_pipe("your_custom_pipe")
|
||||||
|
> pipe.from_bytes(pipe_bytes)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ------------ | --------------- | ------------------------------------------------------------------------- |
|
||||||
|
| `bytes_data` | bytes | The data to load from. |
|
||||||
|
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||||
|
| **RETURNS** | `Pipe` | The pipe. |
|
||||||
|
|
||||||
|
## Serialization fields {#serialization-fields}
|
||||||
|
|
||||||
|
During serialization, spaCy will export several data fields used to restore
|
||||||
|
different aspects of the object. If needed, you can exclude them from
|
||||||
|
serialization by passing in the string names via the `exclude` argument.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> data = pipe.to_disk("/path", exclude=["vocab"])
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ------- | -------------------------------------------------------------- |
|
||||||
|
| `vocab` | The shared [`Vocab`](/api/vocab). |
|
||||||
|
| `cfg` | The config file. You usually don't want to exclude this. |
|
||||||
|
| `model` | The binary model data. You usually don't want to exclude this. |
|
||||||
|
|
|
@ -265,19 +265,20 @@ Create an optimizer for the pipeline component.
|
||||||
|
|
||||||
## SentenceRecognizer.use_params {#use_params tag="method, contextmanager"}
|
## SentenceRecognizer.use_params {#use_params tag="method, contextmanager"}
|
||||||
|
|
||||||
Modify the pipe's model, to use the given parameter values.
|
Modify the pipe's model, to use the given parameter values. At the end of the
|
||||||
|
context, the original parameters are restored.
|
||||||
|
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
> ```python
|
> ```python
|
||||||
> senter = nlp.add_pipe("senter")
|
> senter = nlp.add_pipe("senter")
|
||||||
> with senter.use_params():
|
> with senter.use_params(optimizer.averages):
|
||||||
> senter.to_disk("/best_model")
|
> senter.to_disk("/best_model")
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Type | Description |
|
| Name | Type | Description |
|
||||||
| -------- | ---- | ---------------------------------------------------------------------------------------------------------- |
|
| -------- | ---- | ----------------------------------------- |
|
||||||
| `params` | - | The parameter values to use in the model. At the end of the context, the original parameters are restored. |
|
| `params` | dict | The parameter values to use in the model. |
|
||||||
|
|
||||||
## SentenceRecognizer.to_disk {#to_disk tag="method"}
|
## SentenceRecognizer.to_disk {#to_disk tag="method"}
|
||||||
|
|
||||||
|
|
@ -263,19 +263,20 @@ Create an optimizer for the pipeline component.

## Tagger.use_params {#use_params tag="method, contextmanager"}

Modify the pipe's model, to use the given parameter values.
Modify the pipe's model, to use the given parameter values. At the end of the
context, the original parameters are restored.

> #### Example
>
> ```python
> tagger = nlp.add_pipe("tagger")
> with tagger.use_params():
> with tagger.use_params(optimizer.averages):
>     tagger.to_disk("/best_model")
> ```

| Name     | Type | Description                                                                                                 |
| -------- | ---- | ----------------------------------------- |
| `params` | -    | The parameter values to use in the model. At the end of the context, the original parameters are restored. |
| `params` | dict | The parameter values to use in the model. |

## Tagger.add_label {#add_label tag="method"}

@ -289,10 +290,11 @@ Add a new label to the pipe.

> tagger.add_label("MY_LABEL", {POS: "NOUN"})
> ```

| Name        | Type             | Description                                                      |
| ----------- | ---------------- | ---------------------------------------------------------------- |
| `label`     | str              | The label to add.                                                 |
| `values`    | `Dict[int, str]` | Optional values to map to the label, e.g. a tag map dictionary.   |
| **RETURNS** | int              | `0` if the label is already present, otherwise `1`.               |

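The new **RETURNS** row can be exercised directly; a minimal sketch (the `MY_LABEL` tag and its `POS` mapping are placeholders, as in the example above):

```python
from spacy.symbols import POS

tagger = nlp.add_pipe("tagger")
assert tagger.add_label("MY_LABEL", {POS: "NOUN"}) == 1  # label was newly added
assert tagger.add_label("MY_LABEL", {POS: "NOUN"}) == 0  # label already present
```
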
## Tagger.to_disk {#to_disk tag="method"}

@ -262,7 +262,8 @@ Score a batch of examples.

| Name             | Type                | Description                                                             |
| ---------------- | ------------------- | ----------------------------------------------------------------------- |
| `examples`       | `Iterable[Example]` | The examples to score.                                                   |
| _keyword-only_   |                     |                                                                          |
| `positive_label` | str                 | Optional positive label.                                                 |
| **RETURNS**      | `Dict[str, Any]`    | The scores, produced by [`Scorer.score_cats`](/api/scorer#score_cats).   |

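For reference, a short usage sketch of the keyword-only `positive_label` argument (assuming `examples` is an iterable of `Example` objects with gold-standard categories and that the model was trained with a `POSITIVE` label):

```python
# `scores` is the dictionary produced by Scorer.score_cats
scores = textcat.score(examples, positive_label="POSITIVE")
```
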
@ -292,9 +293,10 @@ Add a new label to the pipe.

> textcat.add_label("MY_LABEL")
> ```

| Name        | Type | Description                                          |
| ----------- | ---- | ----------------------------------------------------- |
| `label`     | str  | The label to add.                                      |
| **RETURNS** | int  | `0` if the label is already present, otherwise `1`.    |

## TextCategorizer.use_params {#use_params tag="method, contextmanager"}

@ -304,13 +306,13 @@ Modify the pipe's model, to use the given parameter values.

>
> ```python
> textcat = nlp.add_pipe("textcat")
> with textcat.use_params():
> with textcat.use_params(optimizer.averages):
>     textcat.to_disk("/best_model")
> ```

| Name     | Type | Description                                                                                                 |
| -------- | ---- | ----------------------------------------- |
| `params` | -    | The parameter values to use in the model. At the end of the context, the original parameters are restored. |
| `params` | dict | The parameter values to use in the model. |

## TextCategorizer.to_disk {#to_disk tag="method"}

@ -8,4 +8,295 @@ api_string_name: tok2vec

api_trainable: true
---

TODO:
<!-- TODO: intro describing component -->

## Config and implementation {#config}

The default config is defined by the pipeline component factory and describes
how the component should be configured. You can override its settings via the
`config` argument on [`nlp.add_pipe`](/api/language#add_pipe) or in your
[`config.cfg` for training](/usage/training#config). See the
[model architectures](/api/architectures) documentation for details on the
architectures and their arguments and hyperparameters.

> #### Example
>
> ```python
> from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
> config = {"model": DEFAULT_TOK2VEC_MODEL}
> nlp.add_pipe("tok2vec", config=config)
> ```

| Setting | Type                                       | Description       | Default                                         |
| ------- | ------------------------------------------ | ----------------- | ----------------------------------------------- |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The model to use. | [HashEmbedCNN](/api/architectures#HashEmbedCNN) |

```python
https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/tok2vec.py
```

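As a sketch of overriding the default model, the block below passes explicit settings for the default embedding/CNN architecture instead of `DEFAULT_TOK2VEC_MODEL`. The registry name `"spacy.HashEmbedCNN.v1"` and the argument names are assumptions based on the architectures registry, not values documented on this page, so treat them as illustrative and check [model architectures](/api/architectures) for the authoritative signature:

```python
# Hypothetical override of the default tok2vec model; the architecture
# name and its argument names are assumed, see /api/architectures.
config = {
    "model": {
        "@architectures": "spacy.HashEmbedCNN.v1",
        "width": 96,
        "depth": 4,
        "embed_size": 2000,
        "window_size": 1,
        "maxout_pieces": 3,
        "subword_features": True,
        "pretrained_vectors": None,
    }
}
nlp.add_pipe("tok2vec", config=config)
```
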
## Tok2Vec.\_\_init\_\_ {#init tag="method"}

> #### Example
>
> ```python
> # Construction via add_pipe with default model
> tok2vec = nlp.add_pipe("tok2vec")
>
> # Construction via add_pipe with custom model
> config = {"model": {"@architectures": "my_tok2vec"}}
> tok2vec = nlp.add_pipe("tok2vec", config=config)
>
> # Construction from class
> from spacy.pipeline import Tok2Vec
> tok2vec = Tok2Vec(nlp.vocab, model)
> ```

Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#add_pipe).

| Name    | Type                                       | Description                                                                                  |
| ------- | ------------------------------------------ | --------------------------------------------------------------------------------------------- |
| `vocab` | `Vocab`                                    | The shared vocabulary.                                                                        |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component.         |
| `name`  | str                                        | String name of the component instance. Used to add entries to the `losses` during training.   |

## Tok2Vec.\_\_call\_\_ {#call tag="method"}

Apply the pipe to one document. The document is modified in place, and returned.
This usually happens under the hood when the `nlp` object is called on a text
and all pipeline components are applied to the `Doc` in order. Both
[`__call__`](/api/tok2vec#call) and [`pipe`](/api/tok2vec#pipe) delegate to the
[`predict`](/api/tok2vec#predict) and
[`set_annotations`](/api/tok2vec#set_annotations) methods.

> #### Example
>
> ```python
> doc = nlp("This is a sentence.")
> tok2vec = nlp.add_pipe("tok2vec")
> # This usually happens under the hood
> processed = tok2vec(doc)
> ```

| Name        | Type  | Description              |
| ----------- | ----- | ------------------------ |
| `doc`       | `Doc` | The document to process. |
| **RETURNS** | `Doc` | The processed document.  |

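The delegation described above can also be spelled out by hand; a minimal sketch that is roughly equivalent to `processed = tok2vec(doc)`:

```python
# Run the model without touching the doc, then write the result back
scores = tok2vec.predict([doc])
tok2vec.set_annotations([doc], scores)
```
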
## Tok2Vec.pipe {#pipe tag="method"}

Apply the pipe to a stream of documents. This usually happens under the hood
when the `nlp` object is called on a text and all pipeline components are
applied to the `Doc` in order. Both [`__call__`](/api/tok2vec#call) and
[`pipe`](/api/tok2vec#pipe) delegate to the [`predict`](/api/tok2vec#predict)
and [`set_annotations`](/api/tok2vec#set_annotations) methods.

> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> for doc in tok2vec.pipe(docs, batch_size=50):
>     pass
> ```

| Name           | Type            | Description                                            |
| -------------- | --------------- | ------------------------------------------------------ |
| `stream`       | `Iterable[Doc]` | A stream of documents.                                  |
| _keyword-only_ |                 |                                                         |
| `batch_size`   | int             | The number of documents to buffer. Defaults to `128`.   |
| **YIELDS**     | `Doc`           | The processed documents in order.                       |

## Tok2Vec.begin_training {#begin_training tag="method"}

Initialize the pipe for training, using data examples if available. Return an
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.

> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> optimizer = tok2vec.begin_training(pipeline=nlp.pipeline)
> ```

| Name           | Type                                                | Description                                                                                                  |
| -------------- | --------------------------------------------------- | ------------------------------------------------------------------------------------------------------------ |
| `get_examples` | `Callable[[], Iterable[Example]]`                   | Optional function that returns gold-standard annotations in the form of [`Example`](/api/example) objects.    |
| _keyword-only_ |                                                     |                                                                                                                |
| `pipeline`     | `List[Tuple[str, Callable]]`                        | Optional list of pipeline components that this component is part of.                                          |
| `sgd`          | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | An optional optimizer. Will be created via [`create_optimizer`](/api/tok2vec#create_optimizer) if not set.    |
| **RETURNS**    | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer.                                                                                                 |

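Where gold-standard data is available, the `get_examples` callback can be passed as well; a minimal sketch (assuming `train_examples` is a list of [`Example`](/api/example) objects prepared elsewhere):

```python
tok2vec = nlp.add_pipe("tok2vec")
optimizer = tok2vec.begin_training(
    lambda: train_examples,  # zero-argument callable returning the examples
    pipeline=nlp.pipeline,
)
```
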
## Tok2Vec.predict {#predict tag="method"}

Apply the pipeline's model to a batch of docs, without modifying them.

> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> scores = tok2vec.predict([doc1, doc2])
> ```

| Name        | Type            | Description                                |
| ----------- | --------------- | ------------------------------------------ |
| `docs`      | `Iterable[Doc]` | The documents to predict.                  |
| **RETURNS** | -               | The model's prediction for each document.  |

## Tok2Vec.set_annotations {#set_annotations tag="method"}

Modify a batch of documents, using pre-computed scores.

> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> scores = tok2vec.predict(docs)
> tok2vec.set_annotations(docs, scores)
> ```

| Name     | Type            | Description                                        |
| -------- | --------------- | --------------------------------------------------- |
| `docs`   | `Iterable[Doc]` | The documents to modify.                             |
| `scores` | -               | The scores to set, produced by `Tok2Vec.predict`.    |

## Tok2Vec.update {#update tag="method"}

Learn from a batch of documents and gold-standard information, updating the
pipe's model. Delegates to [`predict`](/api/tok2vec#predict).

> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> optimizer = nlp.begin_training()
> losses = tok2vec.update(examples, sgd=optimizer)
> ```

| Name              | Type                                                | Description                                                                                                                             |
| ----------------- | --------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------- |
| `examples`        | `Iterable[Example]`                                 | A batch of [`Example`](/api/example) objects to learn from.                                                                               |
| _keyword-only_    |                                                     |                                                                                                                                            |
| `drop`            | float                                               | The dropout rate.                                                                                                                          |
| `set_annotations` | bool                                                | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](/api/tok2vec#set_annotations).     |
| `sgd`             | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer.                                                                                                                             |
| `losses`          | `Dict[str, float]`                                  | Optional record of the loss during training. Updated using the component name as the key.                                                 |
| **RETURNS**       | `Dict[str, float]`                                  | The updated `losses` dictionary.                                                                                                           |

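Putting the arguments in the table to work, a minimal training-loop sketch (assuming the `tok2vec` component from the example above and a list `train_examples` of [`Example`](/api/example) objects; the batch size, dropout rate and number of epochs are arbitrary placeholder values):

```python
import random
from spacy.util import minibatch

optimizer = nlp.begin_training()
losses = {}
for epoch in range(10):
    random.shuffle(train_examples)
    for batch in minibatch(train_examples, size=8):
        # updates the model in place and accumulates the loss under "tok2vec"
        tok2vec.update(batch, drop=0.2, sgd=optimizer, losses=losses)
print(losses["tok2vec"])
```
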
## Tok2Vec.create_optimizer {#create_optimizer tag="method"}

Create an optimizer for the pipeline component.

> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> optimizer = tok2vec.create_optimizer()
> ```

| Name        | Type                                                | Description    |
| ----------- | --------------------------------------------------- | -------------- |
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |

## Tok2Vec.use_params {#use_params tag="method, contextmanager"}

Modify the pipe's model, to use the given parameter values. At the end of the
context, the original parameters are restored.

> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> with tok2vec.use_params(optimizer.averages):
>     tok2vec.to_disk("/best_model")
> ```

| Name     | Type | Description                                |
| -------- | ---- | ------------------------------------------ |
| `params` | dict | The parameter values to use in the model.  |

## Tok2Vec.to_disk {#to_disk tag="method"}

Serialize the pipe to disk.

> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> tok2vec.to_disk("/path/to/tok2vec")
> ```

| Name      | Type            | Description                                                                                                            |
| --------- | --------------- | ------------------------------------------------------------------------------------------------------------------------ |
| `path`    | str / `Path`    | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects.     |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude.                                                 |

## Tok2Vec.from_disk {#from_disk tag="method"}

Load the pipe from disk. Modifies the object in place and returns it.

> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> tok2vec.from_disk("/path/to/tok2vec")
> ```

| Name        | Type            | Description                                                                 |
| ----------- | --------------- | ----------------------------------------------------------------------------- |
| `path`      | str / `Path`    | A path to a directory. Paths may be either strings or `Path`-like objects.    |
| `exclude`   | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude.     |
| **RETURNS** | `Tok2Vec`       | The modified `Tok2Vec` object.                                                 |

## Tok2Vec.to_bytes {#to_bytes tag="method"}

> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> tok2vec_bytes = tok2vec.to_bytes()
> ```

Serialize the pipe to a bytestring.

| Name        | Type            | Description                                                                |
| ----------- | --------------- | ----------------------------------------------------------------------------- |
| `exclude`   | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude.     |
| **RETURNS** | bytes           | The serialized form of the `Tok2Vec` object.                                  |

## Tok2Vec.from_bytes {#from_bytes tag="method"}

Load the pipe from a bytestring. Modifies the object in place and returns it.

> #### Example
>
> ```python
> tok2vec_bytes = tok2vec.to_bytes()
> tok2vec = nlp.add_pipe("tok2vec")
> tok2vec.from_bytes(tok2vec_bytes)
> ```

| Name         | Type            | Description                                                                |
| ------------ | --------------- | ----------------------------------------------------------------------------- |
| `bytes_data` | bytes           | The data to load from.                                                         |
| `exclude`    | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude.      |
| **RETURNS**  | `Tok2Vec`       | The `Tok2Vec` object.                                                          |

## Serialization fields {#serialization-fields}

During serialization, spaCy will export several data fields used to restore
different aspects of the object. If needed, you can exclude them from
serialization by passing in the string names via the `exclude` argument.

> #### Example
>
> ```python
> data = tok2vec.to_disk("/path", exclude=["vocab"])
> ```

| Name    | Description                                                     |
| ------- | --------------------------------------------------------------- |
| `vocab` | The shared [`Vocab`](/api/vocab).                                |
| `cfg`   | The config file. You usually don't want to exclude this.        |
| `model` | The binary model data. You usually don't want to exclude this.  |

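For completeness, a minimal sketch of a byte-level round trip that keeps the shared vocab out of the payload, since the component already gets it from the pipeline:

```python
# Serialize without the vocab field, then restore the weights in place
data = tok2vec.to_bytes(exclude=["vocab"])
tok2vec.from_bytes(data, exclude=["vocab"])
```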