mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 16:22:29 +03:00
Add distillation tests with max cut size
Also fix an endless loop when the max cut size is 0 or 1.
This commit is contained in:
parent
e2591cda36
commit
42fe4edfd7
|
@ -296,7 +296,7 @@ cdef class Parser(TrainablePipe):
|
||||||
# batch uniform length. Since we do not have a gold standard
|
# batch uniform length. Since we do not have a gold standard
|
||||||
# sequence, we use the teacher's predictions as the gold
|
# sequence, we use the teacher's predictions as the gold
|
||||||
# standard.
|
# standard.
|
||||||
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
|
||||||
states = self._init_batch(teacher_step_model, student_docs, max_moves)
|
states = self._init_batch(teacher_step_model, student_docs, max_moves)
|
||||||
else:
|
else:
|
||||||
states = self.moves.init_batch(student_docs)
|
states = self.moves.init_batch(student_docs)
|
||||||
|
|
|
@ -624,7 +624,9 @@ def test_is_distillable():
|
||||||
assert ner.is_distillable
|
assert ner.is_distillable
|
||||||
|
|
||||||
|
|
||||||
def test_distill():
|
@pytest.mark.slow
|
||||||
|
@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
|
||||||
|
def test_distill(max_moves):
|
||||||
teacher = English()
|
teacher = English()
|
||||||
teacher_ner = teacher.add_pipe("ner")
|
teacher_ner = teacher.add_pipe("ner")
|
||||||
train_examples = []
|
train_examples = []
|
||||||
|
@ -642,6 +644,7 @@ def test_distill():
|
||||||
|
|
||||||
student = English()
|
student = English()
|
||||||
student_ner = student.add_pipe("ner")
|
student_ner = student.add_pipe("ner")
|
||||||
|
student_ner.cfg["update_with_oracle_cut_size"] = max_moves
|
||||||
student_ner.initialize(
|
student_ner.initialize(
|
||||||
get_examples=lambda: train_examples, labels=teacher_ner.label_data
|
get_examples=lambda: train_examples, labels=teacher_ner.label_data
|
||||||
)
|
)
|
||||||
|
|
|
@ -402,7 +402,9 @@ def test_is_distillable():
|
||||||
assert parser.is_distillable
|
assert parser.is_distillable
|
||||||
|
|
||||||
|
|
||||||
def test_distill():
|
@pytest.mark.slow
|
||||||
|
@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
|
||||||
|
def test_distill(max_moves):
|
||||||
teacher = English()
|
teacher = English()
|
||||||
teacher_parser = teacher.add_pipe("parser")
|
teacher_parser = teacher.add_pipe("parser")
|
||||||
train_examples = []
|
train_examples = []
|
||||||
|
@ -420,6 +422,7 @@ def test_distill():
|
||||||
|
|
||||||
student = English()
|
student = English()
|
||||||
student_parser = student.add_pipe("parser")
|
student_parser = student.add_pipe("parser")
|
||||||
|
student_parser.cfg["update_with_oracle_cut_size"] = max_moves
|
||||||
student_parser.initialize(
|
student_parser.initialize(
|
||||||
get_examples=lambda: train_examples, labels=teacher_parser.label_data
|
get_examples=lambda: train_examples, labels=teacher_parser.label_data
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user