Add Tok2Vec.distill test

shademe 2022-12-30 16:04:55 +01:00
parent c21ebb31ec
commit 38929c2ca9

@@ -540,3 +540,70 @@ def test_tok2vec_listeners_textcat():
    assert cats1["imperative"] < 0.9
    assert [t.tag_ for t in docs[0]] == ["V", "J", "N"]
    assert [t.tag_ for t in docs[1]] == ["N", "V", "J", "N"]


def test_tok2vec_distill():
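    # Distill a student tok2vec pipe from a trained teacher and check that
    # the student tagger's tok2vec listener receives the student docs as
    # its input.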
    orig_config = Config().from_str(cfg_string_multi_textcat)
    teacher_nlp = util.load_model_from_config(
        orig_config, auto_fill=True, validate=True
    )
    student_nlp = util.load_model_from_config(
        orig_config, auto_fill=True, validate=True
    )
    # Remove pipes that don't currently support distillation.
    teacher_nlp.remove_pipe("textcat_multilabel")
    student_nlp.remove_pipe("textcat_multilabel")
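    # Build separate example lists so that the teacher and the student each
    # tokenize the training texts with their own pipeline.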
    train_examples_teacher = []
    train_examples_student = []
    for t in TRAIN_DATA:
        train_examples_teacher.append(
            Example.from_dict(teacher_nlp.make_doc(t[0]), t[1])
        )
        train_examples_student.append(
            Example.from_dict(student_nlp.make_doc(t[0]), t[1])
        )
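    # Train the teacher for a few steps so that its predictions carry enough
    # signal for the student to distill from.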
    optimizer = teacher_nlp.initialize(lambda: train_examples_teacher)
    for i in range(50):
        losses = {}
        teacher_nlp.update(train_examples_teacher, sgd=optimizer, losses=losses)
    student_nlp.initialize(lambda: train_examples_student)
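    # Wrap the forward pass of the student tagger's tok2vec listener so that
    # we can record the last input it received.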
    student_tagger = student_nlp.get_pipe("tagger")
    tagger_tok2vec = student_tagger.model.get_ref("tok2vec")
    tagger_tok2vec_forward = tagger_tok2vec._func

    def mock_listener_forward(model: Tok2VecListener, inputs, is_train: bool):
        model.attrs["last_input"] = inputs
        return tagger_tok2vec_forward(model, inputs, is_train)

    tagger_tok2vec._func = mock_listener_forward
    # Since Language.distill creates a copy of the student docs to use as
    # its internal teacher docs, we'll need to monkey-patch the tok2vec pipe's
    # distill method.
    student_tok2vec = student_nlp.get_pipe("tok2vec")
    student_tok2vec._old_distill = student_tok2vec.distill

    def tok2vec_distill_wrapper(
        self,
        teacher_pipe,
        teacher_docs,
        student_docs,
        **kwargs,
    ):
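        # The copied teacher docs start out without tok2vec tensors; the
        # wrapped distill call is expected to annotate them.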
        assert all(not doc.tensor.any() for doc in teacher_docs)
        out = self._old_distill(teacher_pipe, teacher_docs, student_docs, **kwargs)
        assert all(doc.tensor.any() for doc in teacher_docs)
        return out

    student_tok2vec.distill = tok2vec_distill_wrapper.__get__(student_tok2vec, Tok2Vec)
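    # Distill the student from the teacher and check that the tagger's
    # listener was fed the student docs.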
    student_docs = [eg.predicted for eg in train_examples_student]
    student_nlp.distill(
        teacher_nlp, student_docs, sgd=optimizer, losses=losses, pipe_map={}
    )
    assert tagger_tok2vec.attrs["last_input"] == student_docs