mirror of https://github.com/explosion/spaCy.git (synced 2025-08-04 12:20:20 +03:00)

TrainablePipe.distill: use Iterable[Example]

This commit is contained in:
parent a4196eddc5
commit a9ed400fd1
@@ -59,8 +59,7 @@ cdef class TrainablePipe(Pipe):
     def distill(self,
                 teacher_pipe: Optional["TrainablePipe"],
-                teacher_docs: Iterable["Doc"],
-                student_docs: Iterable["Doc"],
+                examples: Iterable["Example"],
                 *,
                 drop: float=0.0,
                 sgd: Optional[Optimizer]=None,
                 losses: Optional[Dict[str, float]]=None) -> Dict[str, float]:
@@ -71,10 +70,8 @@ cdef class TrainablePipe(Pipe):
         teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
             from.
-        teacher_docs (Iterable[Doc]): Documents passed through teacher pipes.
-        student_docs (Iterable[Doc]): Documents passed through student pipes.
-            Must contain the same tokens as `teacher_docs` but may have
-            different annotations.
+        examples (Iterable[Example]): Distillation examples. The reference
+            must contain teacher annotations (if any).
         drop (float): dropout rate.
         sgd (Optional[Optimizer]): An optimizer. Will be created via
             create_optimizer if not set.
         losses (Optional[Dict[str, float]]): Optional record of loss during
@@ -89,16 +86,13 @@ cdef class TrainablePipe(Pipe):
         if losses is None:
             losses = {}
         losses.setdefault(self.name, 0.0)
-        if not any(len(doc) for doc in teacher_docs):
-            return losses
-        if not any(len(doc) for doc in student_docs):
-            return losses
+        validate_examples(examples, "TrainablePipe.distill")
         set_dropout_rate(self.model, drop)
         for node in teacher_pipe.model.walk():
             if node.name == "softmax":
                 node.attrs["softmax_normalize"] = True
-        teacher_scores = teacher_pipe.model.predict(teacher_docs)
-        student_scores, bp_student_scores = self.model.begin_update(student_docs)
+        teacher_scores = teacher_pipe.model.predict([eg.reference for eg in examples])
+        student_scores, bp_student_scores = self.model.begin_update([eg.predicted for eg in examples])
         loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
         bp_student_scores(d_scores)
         if sgd is not None:
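The hunks above change the distillation API: instead of parallel teacher_docs/student_docs lists, callers now pass a single stream of Example objects, and distill() pulls the teacher inputs from eg.reference and the student inputs from eg.predicted. A minimal sketch of the new call pattern, mirroring the tests later in this commit (TRAIN_DATA, teacher, teacher_ner, student_ner, and optimizer are assumed to be set up as in those test fixtures):

    from spacy.training import Example

    # One Example per raw text, with empty annotations: the teacher's
    # predictions, not gold labels, drive the distillation loss, so the
    # reference side needs no gold annotations of its own.
    distill_examples = [
        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
    ]

    losses = {}
    student_ner.distill(teacher_ner, distill_examples, sgd=optimizer, losses=losses)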
@@ -209,8 +209,7 @@ cdef class Parser(TrainablePipe):
     def distill(self,
                 teacher_pipe: Optional[TrainablePipe],
-                teacher_docs: Iterable[Doc],
-                student_docs: Iterable[Doc],
+                examples: Iterable["Example"],
                 *,
                 drop: float=0.0,
                 sgd: Optional[Optimizer]=None,
                 losses: Optional[Dict[str, float]]=None) -> Dict[str, float]:
@@ -221,11 +220,8 @@ cdef class Parser(TrainablePipe):
         teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
             from.
-        teacher_docs (Iterable[Doc]): Documents passed through teacher pipes.
-        student_docs (Iterable[Doc]): Documents passed through student pipes.
-            Must contain the same tokens as `teacher_docs` but may have
-            different annotations.
-        drop (float): dropout rate.
+        examples (Iterable[Example]): Distillation examples. The reference
+            must contain teacher annotations (if any).
         sgd (Optional[Optimizer]): An optimizer. Will be created via
             create_optimizer if not set.
         losses (Optional[Dict[str, float]]): Optional record of loss during
             training. Updated using the component name as the key.
@@ -238,14 +234,13 @@ cdef class Parser(TrainablePipe):
             losses = {}
         losses.setdefault(self.name, 0.0)

-        if not any(len(doc) for doc in teacher_docs):
-            return losses
-        if not any(len(doc) for doc in student_docs):
-            return losses
+        validate_examples(examples, "TransitionParser.distill")

         set_dropout_rate(self.model, drop)

-        teacher_step_model = teacher_pipe.model.predict(teacher_docs)
+        student_docs = [eg.predicted for eg in examples]
+
+        teacher_step_model = teacher_pipe.model.predict([eg.reference for eg in examples])
         student_step_model, backprop_tok2vec = self.model.begin_update(student_docs)

         # Add softmax activation, so that we can compute student losses
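In both the base class and the parser, the old paired doc arguments are recovered from the examples: the teacher runs prediction over each Example's reference doc, while the student is updated on the predicted doc. A short illustrative sketch of that split, using only the accessors visible in the hunks above (examples is assumed to be a list of spacy.training.Example objects):

    # Teacher inputs: the reference side of each Example, which must
    # already carry the teacher annotations (if any).
    teacher_inputs = [eg.reference for eg in examples]

    # Student inputs: the predicted side, which shares the same tokens
    # but is annotated by the student during the update.
    student_inputs = [eg.predicted for eg in examples]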
@@ -639,11 +639,13 @@ def test_distill():
         get_examples=lambda: train_examples, labels=teacher_ner.label_data
     )

-    docs = [eg.predicted for eg in train_examples]
+    distill_examples = [
+        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
+    ]

     for i in range(100):
         losses = {}
-        student_ner.distill(teacher_ner, docs, docs, sgd=optimizer, losses=losses)
+        student_ner.distill(teacher_ner, distill_examples, sgd=optimizer, losses=losses)
         assert losses["ner"] < 0.0001

     # test the trained model
@@ -418,11 +418,15 @@ def test_distill():
         get_examples=lambda: train_examples, labels=teacher_parser.label_data
     )

-    docs = [eg.predicted for eg in train_examples]
+    distill_examples = [
+        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
+    ]

     for i in range(200):
         losses = {}
-        student_parser.distill(teacher_parser, docs, docs, sgd=optimizer, losses=losses)
+        student_parser.distill(
+            teacher_parser, distill_examples, sgd=optimizer, losses=losses
+        )
         assert losses["parser"] < 0.0001

     test_text = "I like securities."
@@ -217,12 +217,14 @@ def test_distill():
         get_examples=lambda: train_examples, labels=teacher_lemmatizer.label_data
     )

-    docs = [eg.predicted for eg in train_examples]
+    distill_examples = [
+        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
+    ]

     for i in range(50):
         losses = {}
         student_lemmatizer.distill(
-            teacher_lemmatizer, docs, docs, sgd=optimizer, losses=losses
+            teacher_lemmatizer, distill_examples, sgd=optimizer, losses=losses
         )
         assert losses["trainable_lemmatizer"] < 0.00001

@@ -234,11 +234,15 @@ def test_distill():
         get_examples=lambda: train_examples, labels=teacher_tagger.label_data
     )

-    docs = [eg.predicted for eg in train_examples]
+    distill_examples = [
+        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
+    ]

     for i in range(50):
         losses = {}
-        student_tagger.distill(teacher_tagger, docs, docs, sgd=optimizer, losses=losses)
+        student_tagger.distill(
+            teacher_tagger, distill_examples, sgd=optimizer, losses=losses
+        )
         assert losses["tagger"] < 0.00001

     test_text = "I like blue eggs"