Mirror of https://github.com/explosion/spaCy.git (last synced 2025-08-04 12:20:20 +03:00)
TrainablePipe.distill: use Iterable[Example]
This commit is contained in:
parent a4196eddc5
commit a9ed400fd1
|
@@ -59,8 +59,7 @@ cdef class TrainablePipe(Pipe):
|
||||||
|
|
||||||
def distill(self,
|
def distill(self,
|
||||||
teacher_pipe: Optional["TrainablePipe"],
|
teacher_pipe: Optional["TrainablePipe"],
|
||||||
teacher_docs: Iterable["Doc"],
|
examples: Iterable["Example"],
|
||||||
student_docs: Iterable["Doc"],
|
|
||||||
*,
|
*,
|
||||||
drop: float=0.0,
|
drop: float=0.0,
|
||||||
sgd: Optional[Optimizer]=None,
|
sgd: Optional[Optimizer]=None,
|
||||||
|
@@ -71,10 +70,8 @@ cdef class TrainablePipe(Pipe):
|
||||||
|
|
||||||
teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
|
teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
|
||||||
from.
|
from.
|
||||||
teacher_docs (Iterable[Doc]): Documents passed through teacher pipes.
|
examples (Iterable[Example]): Distillation examples. The reference
|
||||||
student_docs (Iterable[Doc]): Documents passed through student pipes.
|
must contain teacher annotations (if any).
|
||||||
Must contain the same tokens as `teacher_docs` but may have
|
|
||||||
different annotations.
|
|
||||||
drop (float): dropout rate.
|
drop (float): dropout rate.
|
||||||
sgd (Optional[Optimizer]): An optimizer. Will be created via
|
sgd (Optional[Optimizer]): An optimizer. Will be created via
|
||||||
create_optimizer if not set.
|
create_optimizer if not set.
|
||||||
|
@@ -89,16 +86,13 @@ cdef class TrainablePipe(Pipe):
|
||||||
if losses is None:
|
if losses is None:
|
||||||
losses = {}
|
losses = {}
|
||||||
losses.setdefault(self.name, 0.0)
|
losses.setdefault(self.name, 0.0)
|
||||||
if not any(len(doc) for doc in teacher_docs):
|
validate_examples(examples, "TrainablePipe.distill")
|
||||||
return losses
|
|
||||||
if not any(len(doc) for doc in student_docs):
|
|
||||||
return losses
|
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
for node in teacher_pipe.model.walk():
|
for node in teacher_pipe.model.walk():
|
||||||
if node.name == "softmax":
|
if node.name == "softmax":
|
||||||
node.attrs["softmax_normalize"] = True
|
node.attrs["softmax_normalize"] = True
|
||||||
teacher_scores = teacher_pipe.model.predict(teacher_docs)
|
teacher_scores = teacher_pipe.model.predict([eg.reference for eg in examples])
|
||||||
student_scores, bp_student_scores = self.model.begin_update(student_docs)
|
student_scores, bp_student_scores = self.model.begin_update([eg.predicted for eg in examples])
|
||||||
loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
|
loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
|
||||||
bp_student_scores(d_scores)
|
bp_student_scores(d_scores)
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
|
|
|
@@ -209,8 +209,7 @@ cdef class Parser(TrainablePipe):
|
||||||
|
|
||||||
def distill(self,
|
def distill(self,
|
||||||
teacher_pipe: Optional[TrainablePipe],
|
teacher_pipe: Optional[TrainablePipe],
|
||||||
teacher_docs: Iterable[Doc],
|
examples: Iterable["Example"],
|
||||||
student_docs: Iterable[Doc],
|
|
||||||
*,
|
*,
|
||||||
drop: float=0.0,
|
drop: float=0.0,
|
||||||
sgd: Optional[Optimizer]=None,
|
sgd: Optional[Optimizer]=None,
|
||||||
|
@@ -221,11 +220,8 @@ cdef class Parser(TrainablePipe):
|
||||||
|
|
||||||
teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
|
teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
|
||||||
from.
|
from.
|
||||||
teacher_docs (Iterable[Doc]): Documents passed through teacher pipes.
|
examples (Iterable[Example]): Distillation examples. The reference
|
||||||
student_docs (Iterable[Doc]): Documents passed through student pipes.
|
must contain teacher annotations (if any).
|
||||||
Must contain the same tokens as `teacher_docs` but may have
|
|
||||||
different annotations.
|
|
||||||
drop (float): dropout rate.
|
|
||||||
sgd (Optional[Optimizer]): An optimizer. Will be created via
|
sgd (Optional[Optimizer]): An optimizer. Will be created via
|
||||||
create_optimizer if not set.
|
create_optimizer if not set.
|
||||||
losses (Optional[Dict[str, float]]): Optional record of loss during
|
losses (Optional[Dict[str, float]]): Optional record of loss during
|
||||||
|
@@ -238,14 +234,13 @@ cdef class Parser(TrainablePipe):
|
||||||
losses = {}
|
losses = {}
|
||||||
losses.setdefault(self.name, 0.0)
|
losses.setdefault(self.name, 0.0)
|
||||||
|
|
||||||
if not any(len(doc) for doc in teacher_docs):
|
validate_examples(examples, "TransitionParser.distill")
|
||||||
return losses
|
|
||||||
if not any(len(doc) for doc in student_docs):
|
|
||||||
return losses
|
|
||||||
|
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
|
|
||||||
teacher_step_model = teacher_pipe.model.predict(teacher_docs)
|
student_docs = [eg.predicted for eg in examples]
|
||||||
|
|
||||||
|
teacher_step_model = teacher_pipe.model.predict([eg.reference for eg in examples])
|
||||||
student_step_model, backprop_tok2vec = self.model.begin_update(student_docs)
|
student_step_model, backprop_tok2vec = self.model.begin_update(student_docs)
|
||||||
|
|
||||||
# Add softmax activation, so that we can compute student losses
|
# Add softmax activation, so that we can compute student losses
|
||||||
|
|
|
@@ -639,11 +639,13 @@ def test_distill():
|
||||||
get_examples=lambda: train_examples, labels=teacher_ner.label_data
|
get_examples=lambda: train_examples, labels=teacher_ner.label_data
|
||||||
)
|
)
|
||||||
|
|
||||||
docs = [eg.predicted for eg in train_examples]
|
distill_examples = [
|
||||||
|
Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
|
||||||
|
]
|
||||||
|
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
losses = {}
|
losses = {}
|
||||||
student_ner.distill(teacher_ner, docs, docs, sgd=optimizer, losses=losses)
|
student_ner.distill(teacher_ner, distill_examples, sgd=optimizer, losses=losses)
|
||||||
assert losses["ner"] < 0.0001
|
assert losses["ner"] < 0.0001
|
||||||
|
|
||||||
# test the trained model
|
# test the trained model
|
||||||
|
|
|
@@ -418,11 +418,15 @@ def test_distill():
|
||||||
get_examples=lambda: train_examples, labels=teacher_parser.label_data
|
get_examples=lambda: train_examples, labels=teacher_parser.label_data
|
||||||
)
|
)
|
||||||
|
|
||||||
docs = [eg.predicted for eg in train_examples]
|
distill_examples = [
|
||||||
|
Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
|
||||||
|
]
|
||||||
|
|
||||||
for i in range(200):
|
for i in range(200):
|
||||||
losses = {}
|
losses = {}
|
||||||
student_parser.distill(teacher_parser, docs, docs, sgd=optimizer, losses=losses)
|
student_parser.distill(
|
||||||
|
teacher_parser, distill_examples, sgd=optimizer, losses=losses
|
||||||
|
)
|
||||||
assert losses["parser"] < 0.0001
|
assert losses["parser"] < 0.0001
|
||||||
|
|
||||||
test_text = "I like securities."
|
test_text = "I like securities."
|
||||||
|
|
|
@@ -217,12 +217,14 @@ def test_distill():
|
||||||
get_examples=lambda: train_examples, labels=teacher_lemmatizer.label_data
|
get_examples=lambda: train_examples, labels=teacher_lemmatizer.label_data
|
||||||
)
|
)
|
||||||
|
|
||||||
docs = [eg.predicted for eg in train_examples]
|
distill_examples = [
|
||||||
|
Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
|
||||||
|
]
|
||||||
|
|
||||||
for i in range(50):
|
for i in range(50):
|
||||||
losses = {}
|
losses = {}
|
||||||
student_lemmatizer.distill(
|
student_lemmatizer.distill(
|
||||||
teacher_lemmatizer, docs, docs, sgd=optimizer, losses=losses
|
teacher_lemmatizer, distill_examples, sgd=optimizer, losses=losses
|
||||||
)
|
)
|
||||||
assert losses["trainable_lemmatizer"] < 0.00001
|
assert losses["trainable_lemmatizer"] < 0.00001
|
||||||
|
|
||||||
|
|
|
@@ -234,11 +234,15 @@ def test_distill():
|
||||||
get_examples=lambda: train_examples, labels=teacher_tagger.label_data
|
get_examples=lambda: train_examples, labels=teacher_tagger.label_data
|
||||||
)
|
)
|
||||||
|
|
||||||
docs = [eg.predicted for eg in train_examples]
|
distill_examples = [
|
||||||
|
Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
|
||||||
|
]
|
||||||
|
|
||||||
for i in range(50):
|
for i in range(50):
|
||||||
losses = {}
|
losses = {}
|
||||||
student_tagger.distill(teacher_tagger, docs, docs, sgd=optimizer, losses=losses)
|
student_tagger.distill(
|
||||||
|
teacher_tagger, distill_examples, sgd=optimizer, losses=losses
|
||||||
|
)
|
||||||
assert losses["tagger"] < 0.00001
|
assert losses["tagger"] < 0.00001
|
||||||
|
|
||||||
test_text = "I like blue eggs"
|
test_text = "I like blue eggs"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user