Test distillation max cuts in NER

This commit is contained in:
Daniël de Kok 2023-02-01 13:34:31 +01:00
parent 856f5a86f9
commit 8f1acfbc40

View File

@ -623,7 +623,8 @@ def test_is_distillable():
assert ner.is_distillable assert ner.is_distillable
def test_distill(): @pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
def test_distill(max_moves):
teacher = English() teacher = English()
teacher_ner = teacher.add_pipe("ner") teacher_ner = teacher.add_pipe("ner")
train_examples = [] train_examples = []
@ -641,6 +642,7 @@ def test_distill():
student = English() student = English()
student_ner = student.add_pipe("ner") student_ner = student.add_pipe("ner")
student_ner.cfg["update_with_oracle_cut_size"] = max_moves
student_ner.initialize( student_ner.initialize(
get_examples=lambda: train_examples, labels=teacher_ner.label_data get_examples=lambda: train_examples, labels=teacher_ner.label_data
) )