Add distillation tests with max cut size

And fix endless loop when the max cut size is 0 or 1.
This commit is contained in:
Daniël de Kok 2023-12-08 20:38:01 +01:00
parent e2591cda36
commit 42fe4edfd7
3 changed files with 9 additions and 3 deletions

View File

@ -296,7 +296,7 @@ cdef class Parser(TrainablePipe):
# batch uniform length. Since we do not have a gold standard
# sequence, we use the teacher's predictions as the gold
# standard.
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
states = self._init_batch(teacher_step_model, student_docs, max_moves)
else:
states = self.moves.init_batch(student_docs)

View File

@ -624,7 +624,9 @@ def test_is_distillable():
assert ner.is_distillable
def test_distill():
@pytest.mark.slow
@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
def test_distill(max_moves):
teacher = English()
teacher_ner = teacher.add_pipe("ner")
train_examples = []
@ -642,6 +644,7 @@ def test_distill():
student = English()
student_ner = student.add_pipe("ner")
student_ner.cfg["update_with_oracle_cut_size"] = max_moves
student_ner.initialize(
get_examples=lambda: train_examples, labels=teacher_ner.label_data
)

View File

@ -402,7 +402,9 @@ def test_is_distillable():
assert parser.is_distillable
def test_distill():
@pytest.mark.slow
@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
def test_distill(max_moves):
teacher = English()
teacher_parser = teacher.add_pipe("parser")
train_examples = []
@ -420,6 +422,7 @@ def test_distill():
student = English()
student_parser = student.add_pipe("parser")
student_parser.cfg["update_with_oracle_cut_size"] = max_moves
student_parser.initialize(
get_examples=lambda: train_examples, labels=teacher_parser.label_data
)