mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
Tok2Vec
: Add distill
method (#12108)
* `Tok2Vec`: Add `distill` method * `Tok2Vec`: Refactor `update` * Add `Tok2Vec.distill` test * Update `distill` signature to accept `Example`s instead of separate teacher and student docs * Add docs * Remove docstring * Update test * Remove `update` calls from test * Update `Tok2Vec.distill` docstring
This commit is contained in:
parent
41b3a0d932
commit
520279ff7c
|
@ -1,5 +1,6 @@
|
|||
from typing import Sequence, Iterable, Optional, Dict, Callable, List, Any
|
||||
from typing import Sequence, Iterable, Optional, Dict, Callable, List, Any, Tuple
|
||||
from thinc.api import Model, set_dropout_rate, Optimizer, Config
|
||||
from thinc.types import Floats2d
|
||||
from itertools import islice
|
||||
|
||||
from .trainable_pipe import TrainablePipe
|
||||
|
@ -157,39 +158,9 @@ class Tok2Vec(TrainablePipe):
|
|||
|
||||
DOCS: https://spacy.io/api/tok2vec#update
|
||||
"""
|
||||
if losses is None:
|
||||
losses = {}
|
||||
validate_examples(examples, "Tok2Vec.update")
|
||||
docs = [eg.predicted for eg in examples]
|
||||
set_dropout_rate(self.model, drop)
|
||||
tokvecs, bp_tokvecs = self.model.begin_update(docs)
|
||||
d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
|
||||
losses.setdefault(self.name, 0.0)
|
||||
|
||||
def accumulate_gradient(one_d_tokvecs):
|
||||
"""Accumulate tok2vec loss and gradient. This is passed as a callback
|
||||
to all but the last listener. Only the last one does the backprop.
|
||||
"""
|
||||
nonlocal d_tokvecs
|
||||
for i in range(len(one_d_tokvecs)):
|
||||
d_tokvecs[i] += one_d_tokvecs[i]
|
||||
losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
|
||||
return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
|
||||
|
||||
def backprop(one_d_tokvecs):
|
||||
"""Callback to actually do the backprop. Passed to last listener."""
|
||||
accumulate_gradient(one_d_tokvecs)
|
||||
d_docs = bp_tokvecs(d_tokvecs)
|
||||
if sgd is not None:
|
||||
self.finish_update(sgd)
|
||||
return d_docs
|
||||
|
||||
batch_id = Tok2VecListener.get_batch_id(docs)
|
||||
for listener in self.listeners[:-1]:
|
||||
listener.receive(batch_id, tokvecs, accumulate_gradient)
|
||||
if self.listeners:
|
||||
self.listeners[-1].receive(batch_id, tokvecs, backprop)
|
||||
return losses
|
||||
return self._update_with_docs(docs, drop=drop, sgd=sgd, losses=losses)
|
||||
|
||||
def get_loss(self, examples, scores) -> None:
    """No-op: neither `examples` nor `scores` is inspected and nothing is
    computed.

    examples: Ignored.
    scores: Ignored.
    RETURNS (None): Always None.
    """
    return None
|
||||
|
@ -219,6 +190,96 @@ class Tok2Vec(TrainablePipe):
|
|||
def add_label(self, label):
    """Adding labels is not supported here; every call fails.

    label: Ignored.
    RAISES (NotImplementedError): Always.
    """
    raise NotImplementedError
|
||||
|
||||
def distill(
    self,
    teacher_pipe: Optional["TrainablePipe"],
    examples: Iterable["Example"],
    *,
    drop: float = 0.0,
    sgd: Optional[Optimizer] = None,
    losses: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """Performs an update of the student pipe's model using the
    student's distillation examples and sets the annotations
    of the teacher's distillation examples using the teacher pipe.

    teacher_pipe (Optional[TrainablePipe]): The teacher pipe to use
        for prediction.
    examples (Iterable[Example]): Distillation examples. The reference (teacher)
        and predicted (student) docs must have the same number of tokens and the
        same orthography. The iterable is consumed exactly once, so a
        one-shot iterator/generator is acceptable.
    drop (float): dropout rate.
    sgd (Optional[Optimizer]): An optimizer. Forwarded unchanged to the
        student update.
    losses (Optional[Dict[str, float]]): Optional record of loss during
        distillation.
    RETURNS: The updated losses dictionary.
    RAISES (ValueError): If `teacher_pipe` is None.

    DOCS: https://spacy.io/api/tok2vec#distill
    """
    # By default we require a teacher pipe, but there are downstream
    # implementations that don't require a pipe.
    if teacher_pipe is None:
        raise ValueError(Errors.E4002.format(name=self.name))
    # `examples` is only promised to be an Iterable. Collect both doc lists
    # in a single pass so that a generator isn't exhausted after the teacher
    # docs have been gathered (which would leave the student batch empty).
    teacher_docs = []
    student_docs = []
    for eg in examples:
        teacher_docs.append(eg.reference)
        student_docs.append(eg.predicted)
    # Annotate the reference docs with the teacher's predictions first.
    teacher_preds = teacher_pipe.predict(teacher_docs)
    teacher_pipe.set_annotations(teacher_docs, teacher_preds)
    return self._update_with_docs(student_docs, drop=drop, sgd=sgd, losses=losses)
|
||||
|
||||
def _update_with_docs(
    self,
    docs: Iterable[Doc],
    *,
    drop: float = 0.0,
    sgd: Optional[Optimizer] = None,
    losses: Optional[Dict[str, float]] = None,
):
    """Run the forward pass over `docs` and hand the output, together with
    the gradient callbacks, to the registered listeners.

    docs (Iterable[Doc]): The documents to process.
    drop (float): Dropout rate set on the model before the forward pass.
    sgd (Optional[Optimizer]): Forwarded to `_create_backprops`.
    losses (Optional[Dict[str, float]]): Loss record. A fresh dict is used
        when omitted; an entry for this component's name is ensured.
    RETURNS: The losses dictionary.
    """
    losses = {} if losses is None else losses
    losses.setdefault(self.name, 0.0)
    set_dropout_rate(self.model, drop)

    tokvecs, accumulate_gradient, backprop = self._create_backprops(
        docs, losses, sgd=sgd
    )
    batch_id = Tok2VecListener.get_batch_id(docs)

    # Every listener except the last only accumulates its gradient; the last
    # one receives the callback that actually backpropagates.
    for listener in self.listeners[:-1]:
        listener.receive(batch_id, tokvecs, accumulate_gradient)
    for listener in self.listeners[-1:]:
        listener.receive(batch_id, tokvecs, backprop)
    return losses
|
||||
|
||||
def _create_backprops(
|
||||
self,
|
||||
docs: Iterable[Doc],
|
||||
losses: Dict[str, float],
|
||||
*,
|
||||
sgd: Optional[Optimizer] = None,
|
||||
) -> Tuple[Floats2d, Callable, Callable]:
|
||||
tokvecs, bp_tokvecs = self.model.begin_update(docs)
|
||||
d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
|
||||
|
||||
def accumulate_gradient(one_d_tokvecs):
|
||||
"""Accumulate tok2vec loss and gradient. This is passed as a callback
|
||||
to all but the last listener. Only the last one does the backprop.
|
||||
"""
|
||||
nonlocal d_tokvecs
|
||||
for i in range(len(one_d_tokvecs)):
|
||||
d_tokvecs[i] += one_d_tokvecs[i]
|
||||
losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
|
||||
return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
|
||||
|
||||
def backprop(one_d_tokvecs):
|
||||
"""Callback to actually do the backprop. Passed to last listener."""
|
||||
accumulate_gradient(one_d_tokvecs)
|
||||
d_docs = bp_tokvecs(d_tokvecs)
|
||||
if sgd is not None:
|
||||
self.finish_update(sgd)
|
||||
return d_docs
|
||||
|
||||
return tokvecs, accumulate_gradient, backprop
|
||||
|
||||
|
||||
class Tok2VecListener(Model):
|
||||
"""A layer that gets fed its answers from an upstream connection,
|
||||
|
|
|
@ -540,3 +540,86 @@ def test_tok2vec_listeners_textcat():
|
|||
assert cats1["imperative"] < 0.9
|
||||
assert [t.tag_ for t in docs[0]] == ["V", "J", "N"]
|
||||
assert [t.tag_ for t in docs[1]] == ["N", "V", "J", "N"]
|
||||
|
||||
|
||||
cfg_string_distillation = """
|
||||
[nlp]
|
||||
lang = "en"
|
||||
pipeline = ["tok2vec","tagger"]
|
||||
|
||||
[components]
|
||||
|
||||
[components.tagger]
|
||||
factory = "tagger"
|
||||
|
||||
[components.tagger.model]
|
||||
@architectures = "spacy.Tagger.v2"
|
||||
nO = null
|
||||
|
||||
[components.tagger.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecListener.v1"
|
||||
width = ${components.tok2vec.model.encode.width}
|
||||
|
||||
[components.tok2vec]
|
||||
factory = "tok2vec"
|
||||
|
||||
[components.tok2vec.model]
|
||||
@architectures = "spacy.Tok2Vec.v2"
|
||||
|
||||
[components.tok2vec.model.embed]
|
||||
@architectures = "spacy.MultiHashEmbed.v2"
|
||||
width = ${components.tok2vec.model.encode.width}
|
||||
rows = [2000, 1000, 1000, 1000]
|
||||
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
|
||||
include_static_vectors = false
|
||||
|
||||
[components.tok2vec.model.encode]
|
||||
@architectures = "spacy.MaxoutWindowEncoder.v2"
|
||||
width = 96
|
||||
depth = 4
|
||||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
"""
|
||||
|
||||
|
||||
def test_tok2vec_distillation_teacher_annotations():
    """Check that the reference docs' tensors are all-zero before the student
    tok2vec pipe's `distill` call and contain non-zero values afterwards.
    """
    orig_config = Config().from_str(cfg_string_distillation)
    teacher_nlp = util.load_model_from_config(
        orig_config, auto_fill=True, validate=True
    )
    student_nlp = util.load_model_from_config(
        orig_config, auto_fill=True, validate=True
    )

    train_examples_teacher = []
    train_examples_student = []
    for entry in TRAIN_DATA:
        text, annots = entry[0], entry[1]
        teacher_eg = Example.from_dict(teacher_nlp.make_doc(text), annots)
        train_examples_teacher.append(teacher_eg)
        student_eg = Example.from_dict(student_nlp.make_doc(text), annots)
        train_examples_student.append(student_eg)

    optimizer = teacher_nlp.initialize(lambda: train_examples_teacher)
    student_nlp.initialize(lambda: train_examples_student)

    # Since Language.distill creates a copy of the examples to use as
    # its internal teacher/student docs, we'll need to monkey-patch the
    # tok2vec pipe's distill method.
    student_tok2vec = student_nlp.get_pipe("tok2vec")
    student_tok2vec._old_distill = student_tok2vec.distill

    def checked_distill(self, teacher_pipe, examples, **kwargs):
        # Before the call: no reference doc carries tensor values yet.
        assert all(not eg.reference.tensor.any() for eg in examples)
        result = self._old_distill(teacher_pipe, examples, **kwargs)
        # After the call: every reference doc has a non-zero tensor.
        assert all(eg.reference.tensor.any() for eg in examples)
        return result

    # Bind the wrapper to the pipe instance so it is invoked as a method.
    student_tok2vec.distill = checked_distill.__get__(student_tok2vec, Tok2Vec)
    student_nlp.distill(teacher_nlp, train_examples_student, sgd=optimizer, losses={})
|
||||
|
|
|
@ -100,6 +100,43 @@ pipeline components are applied to the `Doc` in order. Both
|
|||
| `doc` | The document to process. ~~Doc~~ |
|
||||
| **RETURNS** | The processed document. ~~Doc~~ |
|
||||
|
||||
## Tok2Vec.distill {id="distill", tag="method,experimental", version="4"}
|
||||
|
||||
Performs an update of the student pipe's model using the student's distillation
|
||||
examples and sets the annotations of the teacher's distillation examples using
|
||||
the teacher pipe.
|
||||
|
||||
Unlike other trainable pipes, the student pipe doesn't directly learn its
|
||||
representations from the teacher. However, since downstream pipes that do
|
||||
perform distillation expect the tok2vec annotations to be present on the
|
||||
correct distillation examples, we need to ensure that they are set beforehand.
|
||||
|
||||
The distillation is performed on ~~Example~~ objects. The `Example.reference`
|
||||
and `Example.predicted` ~~Doc~~s must have the same number of tokens and the
|
||||
same orthography. Even though the reference does not need to have gold
|
||||
annotations, the teacher can add its own annotations when necessary.
|
||||
|
||||
This feature is experimental.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> teacher_pipe = teacher.add_pipe("tok2vec")
|
||||
> student_pipe = student.add_pipe("tok2vec")
|
||||
> optimizer = student.resume_training()
|
||||
> losses = student_pipe.distill(teacher_pipe, examples, sgd=optimizer)
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `teacher_pipe` | The teacher pipe to use for prediction. ~~Optional[TrainablePipe]~~ |
|
||||
| `examples` | Distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
|
||||
| _keyword-only_ | |
|
||||
| `drop` | Dropout rate. ~~float~~ |
|
||||
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
|
||||
| `losses` | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
|
||||
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
|
||||
|
||||
## Tok2Vec.pipe {id="pipe",tag="method"}
|
||||
|
||||
Apply the pipe to a stream of documents. This usually happens under the hood
|
||||
|
|
Loading…
Reference in New Issue
Block a user