From 42fe4edfd7260c413ef23cf87a2408a9b3a7ab28 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= <me@danieldk.eu>
Date: Fri, 8 Dec 2023 20:38:01 +0100
Subject: [PATCH] Add distillation tests with max cut size

And fix endless loop when the max cut size is 0 or 1.
---
 spacy/pipeline/transition_parser.pyx | 2 +-
 spacy/tests/parser/test_ner.py       | 5 ++++-
 spacy/tests/parser/test_parse.py     | 5 ++++-
 3 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index 135756c27..fb4db2da9 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -296,7 +296,7 @@ cdef class Parser(TrainablePipe):
             # batch uniform length. Since we do not have a gold standard
             # sequence, we use the teacher's predictions as the gold
             # standard.
-            max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
+            max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
             states = self._init_batch(teacher_step_model, student_docs, max_moves)
         else:
             states = self.moves.init_batch(student_docs)
diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py
index b6848d380..7c3a9d562 100644
--- a/spacy/tests/parser/test_ner.py
+++ b/spacy/tests/parser/test_ner.py
@@ -624,7 +624,9 @@ def test_is_distillable():
     assert ner.is_distillable
 
 
-def test_distill():
+@pytest.mark.slow
+@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
+def test_distill(max_moves):
     teacher = English()
     teacher_ner = teacher.add_pipe("ner")
     train_examples = []
@@ -642,6 +644,7 @@ def test_distill():
 
     student = English()
     student_ner = student.add_pipe("ner")
+    student_ner.cfg["update_with_oracle_cut_size"] = max_moves
     student_ner.initialize(
         get_examples=lambda: train_examples, labels=teacher_ner.label_data
     )
diff --git a/spacy/tests/parser/test_parse.py b/spacy/tests/parser/test_parse.py
index 42cf5ced9..dbede7edd 100644
--- a/spacy/tests/parser/test_parse.py
+++ b/spacy/tests/parser/test_parse.py
@@ -402,7 +402,9 @@ def test_is_distillable():
     assert parser.is_distillable
 
 
-def test_distill():
+@pytest.mark.slow
+@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
+def test_distill(max_moves):
     teacher = English()
     teacher_parser = teacher.add_pipe("parser")
     train_examples = []
@@ -420,6 +422,7 @@ def test_distill():
 
     student = English()
     student_parser = student.add_pipe("parser")
+    student_parser.cfg["update_with_oracle_cut_size"] = max_moves
     student_parser.initialize(
         get_examples=lambda: train_examples, labels=teacher_parser.label_data
     )