mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Add distillation tests with max cut size
Also fix an endless loop when the max cut size is 0 or 1.
This commit is contained in:
		
							parent
							
								
									e2591cda36
								
							
						
					
					
						commit
						42fe4edfd7
					
				| 
						 | 
					@ -296,7 +296,7 @@ cdef class Parser(TrainablePipe):
 | 
				
			||||||
            # batch uniform length. Since we do not have a gold standard
 | 
					            # batch uniform length. Since we do not have a gold standard
 | 
				
			||||||
            # sequence, we use the teacher's predictions as the gold
 | 
					            # sequence, we use the teacher's predictions as the gold
 | 
				
			||||||
            # standard.
 | 
					            # standard.
 | 
				
			||||||
            max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
 | 
					            max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
 | 
				
			||||||
            states = self._init_batch(teacher_step_model, student_docs, max_moves)
 | 
					            states = self._init_batch(teacher_step_model, student_docs, max_moves)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            states = self.moves.init_batch(student_docs)
 | 
					            states = self.moves.init_batch(student_docs)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -624,7 +624,9 @@ def test_is_distillable():
 | 
				
			||||||
    assert ner.is_distillable
 | 
					    assert ner.is_distillable
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_distill():
 | 
					@pytest.mark.slow
 | 
				
			||||||
 | 
					@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
 | 
				
			||||||
 | 
					def test_distill(max_moves):
 | 
				
			||||||
    teacher = English()
 | 
					    teacher = English()
 | 
				
			||||||
    teacher_ner = teacher.add_pipe("ner")
 | 
					    teacher_ner = teacher.add_pipe("ner")
 | 
				
			||||||
    train_examples = []
 | 
					    train_examples = []
 | 
				
			||||||
| 
						 | 
					@ -642,6 +644,7 @@ def test_distill():
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    student = English()
 | 
					    student = English()
 | 
				
			||||||
    student_ner = student.add_pipe("ner")
 | 
					    student_ner = student.add_pipe("ner")
 | 
				
			||||||
 | 
					    student_ner.cfg["update_with_oracle_cut_size"] = max_moves
 | 
				
			||||||
    student_ner.initialize(
 | 
					    student_ner.initialize(
 | 
				
			||||||
        get_examples=lambda: train_examples, labels=teacher_ner.label_data
 | 
					        get_examples=lambda: train_examples, labels=teacher_ner.label_data
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -402,7 +402,9 @@ def test_is_distillable():
 | 
				
			||||||
    assert parser.is_distillable
 | 
					    assert parser.is_distillable
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_distill():
 | 
					@pytest.mark.slow
 | 
				
			||||||
 | 
					@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
 | 
				
			||||||
 | 
					def test_distill(max_moves):
 | 
				
			||||||
    teacher = English()
 | 
					    teacher = English()
 | 
				
			||||||
    teacher_parser = teacher.add_pipe("parser")
 | 
					    teacher_parser = teacher.add_pipe("parser")
 | 
				
			||||||
    train_examples = []
 | 
					    train_examples = []
 | 
				
			||||||
| 
						 | 
					@ -420,6 +422,7 @@ def test_distill():
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    student = English()
 | 
					    student = English()
 | 
				
			||||||
    student_parser = student.add_pipe("parser")
 | 
					    student_parser = student.add_pipe("parser")
 | 
				
			||||||
 | 
					    student_parser.cfg["update_with_oracle_cut_size"] = max_moves
 | 
				
			||||||
    student_parser.initialize(
 | 
					    student_parser.initialize(
 | 
				
			||||||
        get_examples=lambda: train_examples, labels=teacher_parser.label_data
 | 
					        get_examples=lambda: train_examples, labels=teacher_parser.label_data
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user