Test distillation max cuts in NER

This commit is contained in:
Daniël de Kok 2023-02-01 13:34:31 +01:00
parent 856f5a86f9
commit 8f1acfbc40

View File

@ -623,7 +623,8 @@ def test_is_distillable():
assert ner.is_distillable
def test_distill():
@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
def test_distill(max_moves):
teacher = English()
teacher_ner = teacher.add_pipe("ner")
train_examples = []
@ -641,6 +642,7 @@ def test_distill():
student = English()
student_ner = student.add_pipe("ner")
student_ner.cfg["update_with_oracle_cut_size"] = max_moves
student_ner.initialize(
get_examples=lambda: train_examples, labels=teacher_ner.label_data
)