2022-09-12 15:55:41 +03:00
|
|
|
# cython: infer_types=True, profile=True, binding=True
|
2023-06-26 12:41:03 +03:00
|
|
|
from typing import Callable, Dict, Iterable, Iterator, Optional, Tuple
|
|
|
|
|
|
|
|
import srsly
|
|
|
|
from thinc.api import Model, Optimizer, set_dropout_rate
|
2020-10-08 22:33:49 +03:00
|
|
|
|
|
|
|
from ..tokens.doc cimport Doc
|
|
|
|
|
|
|
|
from .. import util
|
2023-07-19 18:41:29 +03:00
|
|
|
from ..errors import Errors
|
2020-10-08 22:33:49 +03:00
|
|
|
from ..language import Language
|
2023-06-26 12:41:03 +03:00
|
|
|
from ..training import Example, validate_distillation_examples, validate_examples
|
|
|
|
from ..vocab import Vocab
|
|
|
|
from .pipe import Pipe, deserialize_config
|
2020-10-08 22:33:49 +03:00
|
|
|
|
2020-10-10 19:55:07 +03:00
|
|
|
|
2020-10-08 22:33:49 +03:00
|
|
|
cdef class TrainablePipe(Pipe):
    """Base class for trainable spaCy pipeline components; it is not meant to
    be instantiated directly. Trainable components such as the
    EntityRecognizer or TextCategorizer inherit from it, and it defines the
    interface a component has to follow in order to be trained as part of a
    spaCy pipeline.

    DOCS: https://spacy.io/api/pipe
    """
|
|
|
|
def __init__(self, vocab: Vocab, model: Model, name: str, **cfg):
    """Create a trainable pipeline component.

    vocab (Vocab): The shared vocabulary.
    model (thinc.api.Model): The Thinc Model powering the pipeline component.
    name (str): The component instance name.
    **cfg: Additional settings and config parameters, stored (as a new dict)
        on `self.cfg`.

    DOCS: https://spacy.io/api/pipe#init
    """
    # Assigned left to right: vocab, then model, then name.
    self.vocab, self.model, self.name = vocab, model, name
    self.cfg = dict(cfg)
|
|
|
|
|
|
|
|
def __call__(self, Doc doc) -> Doc:
    """Apply the pipe to one document. The document is modified in place,
    and returned. This usually happens under the hood when the nlp object
    is called on a text and all components are applied to the Doc.

    doc (Doc): The Doc to process.
    RETURNS (Doc): The processed Doc. If `predict` or `set_annotations`
        raises and the handler from `self.get_error_handler()` returns
        instead of re-raising, None is returned.

    DOCS: https://spacy.io/api/pipe#call
    """
    error_handler = self.get_error_handler()
    try:
        scores = self.predict([doc])
        self.set_annotations([doc], scores)
        return doc
    except Exception as e:
        # The handler decides whether to re-raise; if it returns, this
        # method falls through and returns None.
        error_handler(self.name, self, [doc], e)
|
2020-10-08 22:33:49 +03:00
|
|
|
|
2023-01-16 12:25:53 +03:00
|
|
|
def distill(self,
            teacher_pipe: Optional["TrainablePipe"],
            examples: Iterable["Example"],
            *,
            drop: float = 0.0,
            sgd: Optional[Optimizer] = None,
            losses: Optional[Dict[str, float]] = None
           ) -> Dict[str, float]:
    """Train a pipe (the student) on the predictions of another pipe
    (the teacher). The student is typically trained on the probability
    distribution of the teacher, but details may differ per pipe.

    teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
        from. The base implementation requires it to be set.
    examples (Iterable[Example]): Distillation examples. The reference
        (teacher) and predicted (student) docs must have the same number of
        tokens and the same orthography. One-shot iterators (e.g.
        generators) are materialized into a list, because the examples are
        traversed several times.
    drop (float): dropout rate.
    sgd (Optional[Optimizer]): An optimizer. If None or False, gradients
        are computed but `finish_update` is not called, so no parameter
        update is applied by this method.
    losses (Optional[Dict[str, float]]): Optional record of loss during
        distillation. Updated using the component name as the key.
    RETURNS (Dict[str, float]): The updated losses dictionary.
    RAISES (ValueError): E4002 if `teacher_pipe` is None.

    DOCS: https://spacy.io/api/pipe#distill
    """
    # By default we require a teacher pipe, but there are downstream
    # implementations that don't require a pipe.
    if teacher_pipe is None:
        raise ValueError(Errors.E4002.format(name=self.name))
    if losses is None:
        losses = {}
    losses.setdefault(self.name, 0.0)
    # Validation, the teacher pass and the student pass each iterate over
    # the examples; a generator would be exhausted after the first pass.
    if isinstance(examples, Iterator):
        examples = list(examples)
    validate_distillation_examples(examples, "TrainablePipe.distill")
    set_dropout_rate(self.model, drop)
    # Set the "softmax_normalize" flag on every softmax node of the teacher
    # (presumably so it yields normalized probabilities). Note that the flag
    # is left set on the teacher model after this call.
    for node in teacher_pipe.model.walk():
        if node.name == "softmax":
            node.attrs["softmax_normalize"] = True
    teacher_scores = teacher_pipe.model.predict([eg.reference for eg in examples])
    student_scores, bp_student_scores = self.model.begin_update([eg.predicted for eg in examples])
    loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
    bp_student_scores(d_scores)
    # Treat sgd=False like None, consistent with `update`.
    if sgd not in (None, False):
        self.finish_update(sgd)
    losses[self.name] += loss
    return losses
|
|
|
|
|
2023-07-19 13:03:31 +03:00
|
|
|
def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
    """Apply the pipe to a stream of documents. This usually happens under
    the hood when the nlp object is called on a text and all components are
    applied to the Doc.

    stream (Iterable[Doc]): A stream of documents.
    batch_size (int): The number of documents to buffer.
    YIELDS (Doc): Processed documents in order.

    Documents are processed in minibatches of `batch_size`. An exception
    raised while processing a batch is passed to the handler returned by
    `self.get_error_handler()`, together with the failing batch; if that
    handler returns instead of re-raising, the documents of that batch are
    not yielded and processing continues with the next batch.

    DOCS: https://spacy.io/api/pipe#pipe
    """
    error_handler = self.get_error_handler()
    for docs in util.minibatch(stream, size=batch_size):
        try:
            scores = self.predict(docs)
            self.set_annotations(docs, scores)
            yield from docs
        except Exception as e:
            error_handler(self.name, self, docs, e)
|
2020-10-08 22:33:49 +03:00
|
|
|
|
|
|
|
def predict(self, docs: Iterable[Doc]):
    """Run the component's model over a batch of docs without modifying
    them, producing a single result for the whole batch.

    docs (Iterable[Doc]): The documents to predict.
    RETURNS: Vector representations of the predictions.

    Subclasses must override this; the base version always raises
    NotImplementedError.

    DOCS: https://spacy.io/api/pipe#predict
    """
    message = Errors.E931.format(name=self.name, parent="TrainablePipe", method="predict")
    raise NotImplementedError(message)
|
|
|
|
|
|
|
|
def set_annotations(self, docs: Iterable[Doc], scores):
    """Write pre-computed scores onto a batch of documents.

    docs (Iterable[Doc]): The documents to modify.
    scores: The scores to assign.

    Subclasses must override this; the base version always raises
    NotImplementedError.

    DOCS: https://spacy.io/api/pipe#set_annotations
    """
    message = Errors.E931.format(name=self.name, parent="TrainablePipe", method="set_annotations")
    raise NotImplementedError(message)
|
|
|
|
|
|
|
|
def update(self,
           examples: Iterable["Example"],
           *,
           drop: float = 0.0,
           sgd: Optional[Optimizer] = None,
           losses: Optional[Dict[str, float]] = None) -> Dict[str, float]:
    """Learn from a batch of documents and gold-standard information,
    updating the pipe's model. Runs the model forward with `begin_update`,
    computes the loss and gradient with `get_loss`, and backpropagates.

    examples (Iterable[Example]): A batch of Example objects. One-shot
        iterators (e.g. generators) are materialized into a list, because
        the examples are traversed several times.
    drop (float): The dropout rate.
    sgd (Optional[thinc.api.Optimizer]): The optimizer. If None or False,
        gradients are computed but `finish_update` is not called.
    losses (Optional[Dict[str, float]]): Optional record of the loss during
        training. Updated using the component name as the key.
    RETURNS (Dict[str, float]): The updated losses dictionary.

    DOCS: https://spacy.io/api/pipe#update
    """
    if losses is None:
        losses = {}
    # Nothing to train when there is no model (missing, None or a bool).
    if not hasattr(self, "model") or self.model in (None, True, False):
        return losses
    losses.setdefault(self.name, 0.0)
    # validate_examples, the emptiness check, begin_update and get_loss all
    # iterate over the examples; a generator would be exhausted after the
    # first pass and training would silently be skipped.
    if isinstance(examples, Iterator):
        examples = list(examples)
    validate_examples(examples, "TrainablePipe.update")
    if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
        # Handle cases where there are no tokens in any docs.
        return losses
    set_dropout_rate(self.model, drop)
    scores, bp_scores = self.model.begin_update([eg.predicted for eg in examples])
    loss, d_scores = self.get_loss(examples, scores)
    bp_scores(d_scores)
    if sgd not in (None, False):
        self.finish_update(sgd)
    losses[self.name] += loss
    return losses
|
|
|
|
|
|
|
|
def rehearse(self,
             examples: Iterable[Example],
             *,
             sgd: Optional[Optimizer] = None,
             losses: Optional[Dict[str, float]] = None,
             **config) -> Dict[str, float]:
    """Perform a "rehearsal" update from a batch of data. Rehearsal updates
    teach the current model to make predictions similar to an initial model,
    to try to address the "catastrophic forgetting" problem. This feature is
    experimental.

    examples (Iterable[Example]): A batch of Example objects.
    sgd (Optional[thinc.api.Optimizer]): The optimizer.
    losses (Optional[Dict[str, float]]): Optional record of the loss during
        training. Updated using the component name as the key.
    RETURNS (Dict[str, float]): The updated losses dictionary.

    The base implementation performs no rehearsal. It leaves the model
    untouched and returns `losses` unchanged (a new empty dict if None was
    given), so that callers always receive the documented return type.

    DOCS: https://spacy.io/api/pipe#rehearse
    """
    if losses is None:
        losses = {}
    return losses
|
|
|
|
|
|
|
|
def get_loss(self, examples: Iterable[Example], scores) -> Tuple[float, float]:
    """Compute the loss, and the gradient of the loss, for a batch of
    examples and the model's predicted scores for them.

    examples (Iterable[Examples]): The batch of examples.
    scores: Scores representing the model's predictions.
    RETURNS (Tuple[float, float]): The loss and the gradient.

    Subclasses must override this; the base version always raises
    NotImplementedError.

    DOCS: https://spacy.io/api/pipe#get_loss
    """
    message = Errors.E931.format(name=self.name, parent="TrainablePipe", method="get_loss")
    raise NotImplementedError(message)
|
|
|
|
|
2023-01-16 12:25:53 +03:00
|
|
|
def get_teacher_student_loss(self, teacher_scores, student_scores):
    """Compute the loss and its gradient for a batch of student scores,
    measured against the teacher's scores.

    teacher_scores: Scores representing the teacher model's predictions.
    student_scores: Scores representing the student model's predictions.
    RETURNS (Tuple[float, float]): The loss and the gradient.

    Pipes that support distillation override this; the base version always
    raises NotImplementedError (see also `is_distillable`).

    DOCS: https://spacy.io/api/pipe#get_teacher_student_loss
    """
    message = Errors.E931.format(name=self.name, parent="TrainablePipe", method="get_teacher_student_loss")
    raise NotImplementedError(message)
|
|
|
|
|
2020-10-08 22:33:49 +03:00
|
|
|
def create_optimizer(self) -> Optimizer:
    """Build an optimizer for this pipeline component, using spaCy's
    default optimizer factory.

    RETURNS (thinc.api.Optimizer): The optimizer.

    DOCS: https://spacy.io/api/pipe#create_optimizer
    """
    optimizer = util.create_default_optimizer()
    return optimizer
|
|
|
|
|
2023-07-19 13:03:31 +03:00
|
|
|
def initialize(self, get_examples: Callable[[], Iterable[Example]], *, nlp: Optional[Language] = None):
    """Initialize the pipe for training, using data examples if available.
    This method needs to be implemented by each TrainablePipe component,
    ensuring the internal model (if available) is initialized properly
    using the provided sample of Example objects.

    get_examples (Callable[[], Iterable[Example]]): Function that
        returns a representative sample of gold-standard Example objects.
    nlp (Optional[Language]): The current nlp object the component is part
        of, or None.

    The base implementation always raises NotImplementedError.

    DOCS: https://spacy.io/api/pipe#initialize
    """
    raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="initialize", name=self.name))
|
|
|
|
|
|
|
|
def add_label(self, label: str) -> int:
    """Add an output label.
    For TrainablePipe components, it is possible to
    extend pretrained models with new labels, but care should be taken to
    avoid the "catastrophic forgetting" problem.

    label (str): The label to add.
    RETURNS (int): 0 if label is already present, otherwise 1.

    The base implementation always raises NotImplementedError.

    DOCS: https://spacy.io/api/pipe#add_label
    """
    # Report TrainablePipe as the parent, like every other abstract method
    # of this class does.
    raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="add_label", name=self.name))
|
|
|
|
|
2023-01-16 12:25:53 +03:00
|
|
|
@property
def is_distillable(self) -> bool:
    """Whether this pipe implements distillation.

    A pipe counts as distillable when its class overrides either `distill`
    or `get_teacher_student_loss`. The usual route is to override
    `get_teacher_student_loss`; more exceptional pipes provide their own
    `distill`. With neither overridden, the pipe cannot be distilled.
    """
    cls = self.__class__
    return (cls.distill is not TrainablePipe.distill
            or cls.get_teacher_student_loss is not TrainablePipe.get_teacher_student_loss)
|
|
|
|
|
2020-10-08 22:33:49 +03:00
|
|
|
@property
def is_trainable(self) -> bool:
    """Whether this component can be trained; always True for a TrainablePipe."""
    return True
|
|
|
|
|
|
|
|
@property
def is_resizable(self) -> bool:
    """Whether the component's model can change its output size.

    RETURNS (bool): True only when the component has a truthy `model` whose
        `attrs` mapping contains a "resize_output" entry; False otherwise
        (no model, None, a bool placeholder, or no such entry).
    """
    model = getattr(self, "model", None)
    if not model:
        return False
    # `update` tolerates bool models, which have no `attrs`, so look the
    # mapping up defensively instead of raising AttributeError.
    return "resize_output" in getattr(model, "attrs", {})
|
|
|
|
|
|
|
|
def _allow_extra_label(self) -> None:
    """Raise an error if the component can not add any more labels.

    The output width nO is read from the model itself or, failing that, from
    the model's "output_layer" reference. If that width is already fully
    used by the current labels and the model is not resizable, no further
    label can be added.

    RAISES (ValueError): E922 if the output width equals the number of
        labels and the component is not resizable.
    """
    nO = None
    if self.model.has_dim("nO"):
        nO = self.model.get_dim("nO")
    elif self.model.has_ref("output_layer") and self.model.get_ref("output_layer").has_dim("nO"):
        nO = self.model.get_ref("output_layer").get_dim("nO")
    if nO is not None and nO == len(self.labels):
        if not self.is_resizable:
            # Report the width that was actually found. When it came from
            # the output layer, the outer model may not define nO at all and
            # calling self.model.get_dim("nO") here would fail.
            raise ValueError(Errors.E922.format(name=self.name, nO=nO))
|
|
|
|
|
|
|
|
def set_output(self, nO: int) -> None:
    """Resize the model's output to `nO` using its "resize_output" hook.

    nO (int): The new output width.
    RAISES (NotImplementedError): E921 if the component is not resizable.
    """
    if not self.is_resizable:
        raise NotImplementedError(Errors.E921)
    resize = self.model.attrs["resize_output"]
    resize(self.model, nO)
|
|
|
|
|
|
|
|
def use_params(self, params: dict):
    """Make the pipe's model use the given parameter values for the
    duration of the context; the original parameters are restored when
    the model's own `use_params` context exits.

    params (dict): The parameter values to use in the model.

    NOTE(review): this is a bare generator function (not wrapped with
    contextlib.contextmanager), so callers must drive it themselves, e.g.
    with next(). Confirm how callers consume it before adding a decorator.

    DOCS: https://spacy.io/api/pipe#use_params
    """
    model_context = self.model.use_params(params)
    with model_context:
        yield
|
|
|
|
|
|
|
|
def finish_update(self, sgd: Optimizer) -> None:
    """Apply the accumulated parameter gradients by delegating to the
    model's `finish_update` with the given optimizer.

    sgd (thinc.api.Optimizer): The optimizer.

    DOCS: https://spacy.io/api/pipe#finish_update
    """
    model = self.model
    model.finish_update(sgd)
|
|
|
|
|
2020-10-10 19:55:07 +03:00
|
|
|
def _validate_serialization_attrs(self):
    """Check that the pipe implements the attributes required for
    serialization. If a subclass implements a custom __init__ method but
    doesn't set these attributes, they currently default to None, so we need
    to perform additional checks.

    RAISES (ValueError): E899 if `vocab` is missing or None, E898 if `model`
        is missing or None.
    """
    if not hasattr(self, "vocab") or self.vocab is None:
        raise ValueError(Errors.E899.format(name=util.get_object_name(self)))
    if not hasattr(self, "model") or self.model is None:
        raise ValueError(Errors.E898.format(name=util.get_object_name(self)))
|
|
|
|
|
2020-10-08 22:33:49 +03:00
|
|
|
def to_bytes(self, *, exclude=tuple()):
    """Serialize the pipe to a bytestring.

    exclude (Iterable[str]): String names of serialization fields to exclude.
    RETURNS (bytes): The serialized object.

    Fields are registered in the order "cfg" (only when `self.cfg` exists and
    is not None), "vocab", "model".

    DOCS: https://spacy.io/api/pipe#to_bytes
    """
    self._validate_serialization_attrs()
    getters = {}
    if getattr(self, "cfg", None) is not None:
        getters["cfg"] = lambda: srsly.json_dumps(self.cfg)
    getters["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
    getters["model"] = self.model.to_bytes
    return util.to_bytes(getters, exclude)
|
|
|
|
|
|
|
|
def from_bytes(self, bytes_data, *, exclude=tuple()):
    """Load the pipe from a bytestring, updating it in place.

    bytes_data (bytes): The data to load from.
    exclude (Iterable[str]): String names of serialization fields to exclude.
    RETURNS (TrainablePipe): The loaded object.

    DOCS: https://spacy.io/api/pipe#from_bytes
    """
    self._validate_serialization_attrs()

    def _read_model(data):
        # An AttributeError from the model is reported as E149.
        try:
            self.model.from_bytes(data)
        except AttributeError:
            raise ValueError(Errors.E149) from None

    setters = {}
    if getattr(self, "cfg", None) is not None:
        setters["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
    setters["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
    setters["model"] = _read_model
    util.from_bytes(bytes_data, setters, exclude)
    return self
|
|
|
|
|
|
|
|
def to_disk(self, path, *, exclude=tuple()):
    """Write the pipe to a directory on disk.

    path (str / Path): Path to a directory.
    exclude (Iterable[str]): String names of serialization fields to exclude.

    Entries are registered in the order "cfg" (only when `self.cfg` exists
    and is not None), "vocab", "model".

    DOCS: https://spacy.io/api/pipe#to_disk
    """
    self._validate_serialization_attrs()
    writers = {}
    if getattr(self, "cfg", None) is not None:
        writers["cfg"] = lambda p: srsly.write_json(p, self.cfg)
    writers["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
    writers["model"] = lambda p: self.model.to_disk(p)
    util.to_disk(path, writers, exclude)
|
|
|
|
|
|
|
|
def from_disk(self, path, *, exclude=tuple()):
    """Load the pipe from a directory on disk, updating it in place.

    path (str / Path): Path to a directory.
    exclude (Iterable[str]): String names of serialization fields to exclude.
    RETURNS (TrainablePipe): The loaded object.

    DOCS: https://spacy.io/api/pipe#from_disk
    """
    self._validate_serialization_attrs()

    def _read_model(model_path):
        # The model file is read as raw bytes; an AttributeError from the
        # model is reported as E149.
        try:
            with open(model_path, "rb") as handle:
                model_bytes = handle.read()
                self.model.from_bytes(model_bytes)
        except AttributeError:
            raise ValueError(Errors.E149) from None

    readers = {}
    if getattr(self, "cfg", None) is not None:
        readers["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
    readers["vocab"] = lambda p: self.vocab.from_disk(p, exclude=exclude)
    readers["model"] = _read_model
    util.from_disk(path, readers, exclude)
    return self
|
Store activations in `Doc`s when `save_activations` is enabled (#11002)
* Store activations in Doc when `store_activations` is enabled
This change adds the new `activations` attribute to `Doc`. This
attribute can be used by trainable pipes to store their activations,
probabilities, and guesses for downstream users.
As an example, this change modifies the `tagger` and `senter` pipes to
add a `store_activations` option. When this option is enabled, the
probabilities and guesses are stored in `set_annotations`.
* Change type of `store_activations` to `Union[bool, List[str]]`
When the value is:
- A bool: all activations are stored when set to `True`.
- A List[str]: the activations named in the list are stored
* Formatting fixes in Tagger
* Support store_activations in spancat and morphologizer
* Make Doc.activations type visible to MyPy
* textcat/textcat_multilabel: add store_activations option
* trainable_lemmatizer/entity_linker: add store_activations option
* parser/ner: do not currently support returning activations
* Extend tagger and senter tests
So that they, like the other tests, also check that we get no
activations if no activations were requested.
* Document `Doc.activations` and `store_activations` in the relevant pipes
* Start errors/warnings at higher numbers to avoid merge conflicts
This avoids conflicts between the master and v4 branches.
* Add `store_activations` to docstrings.
* Replace store_activations setter by set_store_activations method
Setters that take a different type than what the getter returns are still
problematic for MyPy. Replace the setter by a method, so that type inference
works everywhere.
* Use dict comprehension suggested by @svlandeg
* Revert "Use dict comprehension suggested by @svlandeg"
This reverts commit 6e7b958f7060397965176c69649e5414f1f24988.
* EntityLinker: add type annotations to _add_activations
* _store_activations: make kwarg-only, remove doc_scores_lens arg
* set_annotations: add type annotations
* Apply suggestions from code review
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
* TextCat.predict: return dict
* Make the `TrainablePipe.store_activations` property a bool
This means that we can also bring back `store_activations` setter.
* Remove `TrainablePipe.activations`
We do not need to enumerate the activations anymore since `store_activations` is
`bool`.
* Add type annotations for activations in predict/set_annotations
* Rename `TrainablePipe.store_activations` to `save_activations`
* Error E1400 is not used anymore
This error was used when activations were still `Union[bool, List[str]]`.
* Change wording in API docs after store -> save change
* docs: tag (save_)activations as new in spaCy 4.0
* Fix copied line in morphologizer activations test
* Don't train in any test_save_activations test
* Rename activations
- "probs" -> "probabilities"
- "guesses" -> "label_ids", except in the edit tree lemmatizer, where
"guesses" -> "tree_ids".
* Remove unused W400 warning.
This warning was used when we still allowed the user to specify
which activations to save.
* Formatting fixes
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
* Replace "kb_ids" by a constant
* spancat: replace a cast by an assertion
* Fix EOF spacing
* Fix comments in test_save_activations tests
* Do not set RNG seed in activation saving tests
* Revert "spancat: replace a cast by an assertion"
This reverts commit 0bd5730d16432443a2b247316928d4f789ad8741.
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
2022-09-13 10:51:12 +03:00
|
|
|
|
|
|
|
@property
def save_activations(self):
    """Return the component's `_save_activations` flag (whether it should
    save its activations)."""
    return self._save_activations

@save_activations.setter
def save_activations(self, value: bool):
    # Store the flag as given; no conversion or validation is applied.
    self._save_activations = value
|