mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
Add distillation tests with max cut size
And fix endless loop when the max cut size is 0 or 1.
This commit is contained in:
parent
e2591cda36
commit
42fe4edfd7
|
@ -296,7 +296,7 @@ cdef class Parser(TrainablePipe):
|
|||
# batch uniform length. Since we do not have a gold standard
|
||||
# sequence, we use the teacher's predictions as the gold
|
||||
# standard.
|
||||
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
||||
max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
|
||||
states = self._init_batch(teacher_step_model, student_docs, max_moves)
|
||||
else:
|
||||
states = self.moves.init_batch(student_docs)
|
||||
|
|
|
@ -624,7 +624,9 @@ def test_is_distillable():
|
|||
assert ner.is_distillable
|
||||
|
||||
|
||||
def test_distill():
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
|
||||
def test_distill(max_moves):
|
||||
teacher = English()
|
||||
teacher_ner = teacher.add_pipe("ner")
|
||||
train_examples = []
|
||||
|
@ -642,6 +644,7 @@ def test_distill():
|
|||
|
||||
student = English()
|
||||
student_ner = student.add_pipe("ner")
|
||||
student_ner.cfg["update_with_oracle_cut_size"] = max_moves
|
||||
student_ner.initialize(
|
||||
get_examples=lambda: train_examples, labels=teacher_ner.label_data
|
||||
)
|
||||
|
|
|
@ -402,7 +402,9 @@ def test_is_distillable():
|
|||
assert parser.is_distillable
|
||||
|
||||
|
||||
def test_distill():
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
|
||||
def test_distill(max_moves):
|
||||
teacher = English()
|
||||
teacher_parser = teacher.add_pipe("parser")
|
||||
train_examples = []
|
||||
|
@ -420,6 +422,7 @@ def test_distill():
|
|||
|
||||
student = English()
|
||||
student_parser = student.add_pipe("parser")
|
||||
student_parser.cfg["update_with_oracle_cut_size"] = max_moves
|
||||
student_parser.initialize(
|
||||
get_examples=lambda: train_examples, labels=teacher_parser.label_data
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user