diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index 135756c27..fb4db2da9 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -296,7 +296,7 @@ cdef class Parser(TrainablePipe): # batch uniform length. Since we do not have a gold standard # sequence, we use the teacher's predictions as the gold # standard. - max_moves = int(random.uniform(max_moves // 2, max_moves * 2)) + max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2)) states = self._init_batch(teacher_step_model, student_docs, max_moves) else: states = self.moves.init_batch(student_docs) diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py index b6848d380..7c3a9d562 100644 --- a/spacy/tests/parser/test_ner.py +++ b/spacy/tests/parser/test_ner.py @@ -624,7 +624,9 @@ def test_is_distillable(): assert ner.is_distillable -def test_distill(): +@pytest.mark.slow +@pytest.mark.parametrize("max_moves", [0, 1, 5, 100]) +def test_distill(max_moves): teacher = English() teacher_ner = teacher.add_pipe("ner") train_examples = [] @@ -642,6 +644,7 @@ def test_distill(): student = English() student_ner = student.add_pipe("ner") + student_ner.cfg["update_with_oracle_cut_size"] = max_moves student_ner.initialize( get_examples=lambda: train_examples, labels=teacher_ner.label_data ) diff --git a/spacy/tests/parser/test_parse.py b/spacy/tests/parser/test_parse.py index 42cf5ced9..dbede7edd 100644 --- a/spacy/tests/parser/test_parse.py +++ b/spacy/tests/parser/test_parse.py @@ -402,7 +402,9 @@ def test_is_distillable(): assert parser.is_distillable -def test_distill(): +@pytest.mark.slow +@pytest.mark.parametrize("max_moves", [0, 1, 5, 100]) +def test_distill(max_moves): teacher = English() teacher_parser = teacher.add_pipe("parser") train_examples = [] @@ -420,6 +422,7 @@ def test_distill(): student = English() student_parser = student.add_pipe("parser") + student_parser.cfg["update_with_oracle_cut_size"] = max_moves student_parser.initialize( get_examples=lambda: train_examples, labels=teacher_parser.label_data )