mirror of https://github.com/explosion/spaCy.git
synced 2025-07-11 08:42:28 +03:00

Update test

commit b56434c73b
parent 288d88a472
@@ -542,8 +542,48 @@ def test_tok2vec_listeners_textcat():
     assert [t.tag_ for t in docs[1]] == ["N", "V", "J", "N"]
 
 
-def test_tok2vec_distill():
-    orig_config = Config().from_str(cfg_string_multi_textcat)
+cfg_string_distillation = """
+    [nlp]
+    lang = "en"
+    pipeline = ["tok2vec","tagger"]
+
+    [components]
+
+    [components.tagger]
+    factory = "tagger"
+
+    [components.tagger.model]
+    @architectures = "spacy.Tagger.v2"
+    nO = null
+
+    [components.tagger.model.tok2vec]
+    @architectures = "spacy.Tok2VecListener.v1"
+    width = ${components.tok2vec.model.encode.width}
+
+    [components.tok2vec]
+    factory = "tok2vec"
+
+    [components.tok2vec.model]
+    @architectures = "spacy.Tok2Vec.v2"
+
+    [components.tok2vec.model.embed]
+    @architectures = "spacy.MultiHashEmbed.v2"
+    width = ${components.tok2vec.model.encode.width}
+    rows = [2000, 1000, 1000, 1000]
+    attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
+    include_static_vectors = false
+
+    [components.tok2vec.model.encode]
+    @architectures = "spacy.MaxoutWindowEncoder.v2"
+    width = 96
+    depth = 4
+    window_size = 1
+    maxout_pieces = 3
+    """
+
+
+def test_tok2vec_distillation_teacher_annotations():
+    orig_config = Config().from_str(cfg_string_distillation)
     teacher_nlp = util.load_model_from_config(
         orig_config, auto_fill=True, validate=True
     )
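Note: the added config wires the tagger's Tok2VecListener to the standalone tok2vec component, and the ${components.tok2vec.model.encode.width} reference keeps the listener's width in sync with the encoder's. A minimal sketch of how that interpolation resolves, assuming the cfg_string_distillation string from the hunk above and thinc's Config class (which spaCy configs are built on):

from thinc.api import Config

# Parse the config string; interpolation is on by default, so
# ${components.tok2vec.model.encode.width} is replaced by its value.
config = Config().from_str(cfg_string_distillation)

# The listener's width now matches the encoder's literal width of 96.
assert config["components"]["tagger"]["model"]["tok2vec"]["width"] == 96
assert config["components"]["tok2vec"]["model"]["encode"]["width"] == 96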
@@ -551,10 +591,6 @@ def test_tok2vec_distill():
         orig_config, auto_fill=True, validate=True
     )
 
-    # Remove pipes that don't currently support distillation.
-    teacher_nlp.remove_pipe("textcat_multilabel")
-    student_nlp.remove_pipe("textcat_multilabel")
-
     train_examples_teacher = []
     train_examples_student = []
     for t in TRAIN_DATA:
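Note: the removed remove_pipe calls are obsolete because the new cfg_string_distillation pipeline contains only tok2vec and tagger, so there is no textcat_multilabel component to strip. For orientation, the for t in TRAIN_DATA: loop visible in the context builds spacy.training.Example objects for teacher and student; a hedged sketch of that unchanged scaffolding (names mirror the test, and TRAIN_DATA is assumed to hold (text, annotations) pairs):

from spacy.training import Example

train_examples_teacher = []
train_examples_student = []
for t in TRAIN_DATA:
    # Each Example pairs a fresh doc (the prediction side) with the gold
    # annotations; teacher and student each get their own copies.
    train_examples_teacher.append(
        Example.from_dict(teacher_nlp.make_doc(t[0]), t[1])
    )
    train_examples_student.append(
        Example.from_dict(student_nlp.make_doc(t[0]), t[1])
    )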
@@ -571,39 +607,25 @@ def test_tok2vec_distill():
     teacher_nlp.update(train_examples_teacher, sgd=optimizer, losses=losses)
 
     student_nlp.initialize(lambda: train_examples_student)
-    student_tagger = student_nlp.get_pipe("tagger")
-
-    tagger_tok2vec = student_tagger.model.get_ref("tok2vec")
-    tagger_tok2vec_forward = tagger_tok2vec._func
-
-    def mock_listener_forward(model: Tok2VecListener, inputs, is_train: bool):
-        model.attrs["last_input"] = inputs
-        return tagger_tok2vec_forward(model, inputs, is_train)
-
-    tagger_tok2vec._func = mock_listener_forward
-
-    # Since Language.distill creates a copy of the student docs to use as
-    # its internal teacher docs, we'll need to monkey-patch the tok2vec pipe's
-    # distill method.
+    # Since Language.distill creates a copy of the examples to use as
+    # its internal teacher/student docs, we'll need to monkey-patch the
+    # tok2vec pipe's distill method.
     student_tok2vec = student_nlp.get_pipe("tok2vec")
     student_tok2vec._old_distill = student_tok2vec.distill
 
     def tok2vec_distill_wrapper(
         self,
         teacher_pipe,
-        teacher_docs,
-        student_docs,
+        examples,
         **kwargs,
     ):
-        assert all(not doc.tensor.any() for doc in teacher_docs)
-        out = self._old_distill(teacher_pipe, teacher_docs, student_docs, **kwargs)
-        assert all(doc.tensor.any() for doc in teacher_docs)
+        assert all(not eg.reference.tensor.any() for eg in examples)
+        out = self._old_distill(teacher_pipe, examples, **kwargs)
+        assert all(eg.reference.tensor.any() for eg in examples)
         return out
 
     student_tok2vec.distill = tok2vec_distill_wrapper.__get__(student_tok2vec, Tok2Vec)
 
-    student_docs = [eg.predicted for eg in train_examples_student]
     student_nlp.distill(
-        teacher_nlp, student_docs, sgd=optimizer, losses=losses, pipe_map={}
+        teacher_nlp, train_examples_student, sgd=optimizer, losses=losses
     )
-    assert tagger_tok2vec.attrs["last_input"] == student_docs
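Note: the updated assertions check eg.reference.tensor because, per the new distill signature shown in this hunk, each Example carries the teacher's annotations on its reference doc. The wrapper itself is attached via tok2vec_distill_wrapper.__get__(student_tok2vec, Tok2Vec), which uses Python's descriptor protocol to bind a plain function to one instance so that self is passed correctly. A self-contained toy sketch of that monkey-patching pattern (toy classes, not the spaCy API):

class Pipe:
    def distill(self, examples):
        return f"distilled {len(examples)} examples"

pipe = Pipe()
pipe._old_distill = pipe.distill  # keep a handle on the original bound method

def distill_wrapper(self, examples):
    # Pre-/post-condition checks would go here, as in the test above.
    return self._old_distill(examples)

# function.__get__(instance, cls) returns a bound method, so the patched
# pipe.distill receives `self` exactly like the original did.
pipe.distill = distill_wrapper.__get__(pipe, Pipe)
assert pipe.distill(["a", "b"]) == "distilled 2 examples"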