TrainablePipe.distill: use Iterable[Example]

Daniël de Kok 2023-01-12 16:29:14 +01:00
parent a4196eddc5
commit a9ed400fd1
6 changed files with 33 additions and 32 deletions


@@ -59,8 +59,7 @@ cdef class TrainablePipe(Pipe):
     def distill(self,
                 teacher_pipe: Optional["TrainablePipe"],
-                teacher_docs: Iterable["Doc"],
-                student_docs: Iterable["Doc"],
+                examples: Iterable["Example"],
                 *,
                 drop: float=0.0,
                 sgd: Optional[Optimizer]=None,
@@ -71,10 +70,8 @@ cdef class TrainablePipe(Pipe):
         teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
             from.
-        teacher_docs (Iterable[Doc]): Documents passed through teacher pipes.
-        student_docs (Iterable[Doc]): Documents passed through student pipes.
-            Must contain the same tokens as `teacher_docs` but may have
-            different annotations.
+        examples (Iterable[Example]): Distillation examples. The reference
+            must contain teacher annotations (if any).
         drop (float): dropout rate.
         sgd (Optional[Optimizer]): An optimizer. Will be created via
             create_optimizer if not set.
@@ -89,16 +86,13 @@ cdef class TrainablePipe(Pipe):
         if losses is None:
             losses = {}
         losses.setdefault(self.name, 0.0)
-        if not any(len(doc) for doc in teacher_docs):
-            return losses
-        if not any(len(doc) for doc in student_docs):
-            return losses
+        validate_examples(examples, "TrainablePipe.distill")
         set_dropout_rate(self.model, drop)
         for node in teacher_pipe.model.walk():
             if node.name == "softmax":
                 node.attrs["softmax_normalize"] = True
-        teacher_scores = teacher_pipe.model.predict(teacher_docs)
-        student_scores, bp_student_scores = self.model.begin_update(student_docs)
+        teacher_scores = teacher_pipe.model.predict([eg.reference for eg in examples])
+        student_scores, bp_student_scores = self.model.begin_update([eg.predicted for eg in examples])
         loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
         bp_student_scores(d_scores)
         if sgd is not None:
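
With this change, callers no longer pass two aligned batches of Docs; they pass one batch of Examples whose reference Doc is the teacher's input (carrying teacher annotations, if any) and whose predicted Doc is the student's input. A minimal sketch of the new calling convention against this branch, mirroring the updated tests below; the teacher pipeline, texts, and initialization are hypothetical stand-ins for the test fixtures:

import spacy
from spacy.training import Example

texts = ["I like London.", "She works in Berlin."]  # hypothetical data
teacher = spacy.load("en_core_web_sm")              # assumed trained teacher pipeline
teacher_ner = teacher.get_pipe("ner")

student = spacy.blank("en")
student_ner = student.add_pipe("ner")
init_examples = [Example.from_dict(student.make_doc(t), {}) for t in texts]
student_ner.initialize(get_examples=lambda: init_examples, nlp=student,
                       labels=teacher_ner.label_data)
optimizer = student_ner.create_optimizer()

# One batch of Examples replaces the old (teacher_docs, student_docs) pair;
# the empty dict leaves the reference side free of gold annotations.
distill_examples = [Example.from_dict(teacher.make_doc(t), {}) for t in texts]
losses = {}
student_ner.distill(teacher_ner, distill_examples, sgd=optimizer, losses=losses)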


@@ -209,8 +209,7 @@ cdef class Parser(TrainablePipe):
     def distill(self,
                 teacher_pipe: Optional[TrainablePipe],
-                teacher_docs: Iterable[Doc],
-                student_docs: Iterable[Doc],
+                examples: Iterable["Example"],
                 *,
                 drop: float=0.0,
                 sgd: Optional[Optimizer]=None,
@@ -221,11 +220,8 @@ cdef class Parser(TrainablePipe):
         teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
             from.
-        teacher_docs (Iterable[Doc]): Documents passed through teacher pipes.
-        student_docs (Iterable[Doc]): Documents passed through student pipes.
-            Must contain the same tokens as `teacher_docs` but may have
-            different annotations.
-        drop (float): dropout rate.
+        examples (Iterable[Example]): Distillation examples. The reference
+            must contain teacher annotations (if any).
         sgd (Optional[Optimizer]): An optimizer. Will be created via
             create_optimizer if not set.
         losses (Optional[Dict[str, float]]): Optional record of loss during
@@ -238,14 +234,13 @@ cdef class Parser(TrainablePipe):
             losses = {}
         losses.setdefault(self.name, 0.0)
-        if not any(len(doc) for doc in teacher_docs):
-            return losses
-        if not any(len(doc) for doc in student_docs):
-            return losses
+        validate_examples(examples, "TransitionParser.distill")
         set_dropout_rate(self.model, drop)
-        teacher_step_model = teacher_pipe.model.predict(teacher_docs)
+        student_docs = [eg.predicted for eg in examples]
+        teacher_step_model = teacher_pipe.model.predict([eg.reference for eg in examples])
         student_step_model, backprop_tok2vec = self.model.begin_update(student_docs)
         # Add softmax activation, so that we can compute student losses
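
The parser extracts the student-side Docs up front because the transition-based model steps over them statefully, while the teacher predicts over eg.reference. Both sides of an Example built this way share one tokenization, which is what replaces the old "must contain the same tokens" contract from the removed docstring. A small self-contained sketch (the text is borrowed from the parser test):

import spacy
from spacy.training import Example

nlp = spacy.blank("en")
doc = nlp.make_doc("I like securities.")

# Example.from_dict(doc, {}) keeps `doc` as the predicted side and builds the
# reference side from the (empty) annotation dict over the same tokens, so the
# reference can later carry teacher annotations without changing tokenization.
eg = Example.from_dict(doc, {})
assert [t.text for t in eg.predicted] == [t.text for t in eg.reference]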


@@ -639,11 +639,13 @@ def test_distill():
         get_examples=lambda: train_examples, labels=teacher_ner.label_data
     )
-    docs = [eg.predicted for eg in train_examples]
+    distill_examples = [
+        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
+    ]
     for i in range(100):
         losses = {}
-        student_ner.distill(teacher_ner, docs, docs, sgd=optimizer, losses=losses)
+        student_ner.distill(teacher_ner, distill_examples, sgd=optimizer, losses=losses)
     assert losses["ner"] < 0.0001
     # test the trained model


@@ -418,11 +418,15 @@ def test_distill():
         get_examples=lambda: train_examples, labels=teacher_parser.label_data
     )
-    docs = [eg.predicted for eg in train_examples]
+    distill_examples = [
+        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
+    ]
     for i in range(200):
         losses = {}
-        student_parser.distill(teacher_parser, docs, docs, sgd=optimizer, losses=losses)
+        student_parser.distill(
+            teacher_parser, distill_examples, sgd=optimizer, losses=losses
+        )
     assert losses["parser"] < 0.0001
     test_text = "I like securities."


@@ -217,12 +217,14 @@ def test_distill():
         get_examples=lambda: train_examples, labels=teacher_lemmatizer.label_data
     )
-    docs = [eg.predicted for eg in train_examples]
+    distill_examples = [
+        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
+    ]
     for i in range(50):
         losses = {}
         student_lemmatizer.distill(
-            teacher_lemmatizer, docs, docs, sgd=optimizer, losses=losses
+            teacher_lemmatizer, distill_examples, sgd=optimizer, losses=losses
         )
     assert losses["trainable_lemmatizer"] < 0.00001


@@ -234,11 +234,15 @@ def test_distill():
         get_examples=lambda: train_examples, labels=teacher_tagger.label_data
     )
-    docs = [eg.predicted for eg in train_examples]
+    distill_examples = [
+        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
+    ]
     for i in range(50):
         losses = {}
-        student_tagger.distill(teacher_tagger, docs, docs, sgd=optimizer, losses=losses)
+        student_tagger.distill(
+            teacher_tagger, distill_examples, sgd=optimizer, losses=losses
+        )
     assert losses["tagger"] < 0.00001
     test_text = "I like blue eggs"