mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Add distillation tests with max cut size
And fix endless loop when the max cut size is 0 or 1.
This commit is contained in:
		
							parent
							
								
									e2591cda36
								
							
						
					
					
						commit
						42fe4edfd7
					
				| 
						 | 
				
			
			@ -296,7 +296,7 @@ cdef class Parser(TrainablePipe):
 | 
			
		|||
            # batch uniform length. Since we do not have a gold standard
 | 
			
		||||
            # sequence, we use the teacher's predictions as the gold
 | 
			
		||||
            # standard.
 | 
			
		||||
            max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
 | 
			
		||||
            max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
 | 
			
		||||
            states = self._init_batch(teacher_step_model, student_docs, max_moves)
 | 
			
		||||
        else:
 | 
			
		||||
            states = self.moves.init_batch(student_docs)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -624,7 +624,9 @@ def test_is_distillable():
 | 
			
		|||
    assert ner.is_distillable
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_distill():
 | 
			
		||||
@pytest.mark.slow
 | 
			
		||||
@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
 | 
			
		||||
def test_distill(max_moves):
 | 
			
		||||
    teacher = English()
 | 
			
		||||
    teacher_ner = teacher.add_pipe("ner")
 | 
			
		||||
    train_examples = []
 | 
			
		||||
| 
						 | 
				
			
			@ -642,6 +644,7 @@ def test_distill():
 | 
			
		|||
 | 
			
		||||
    student = English()
 | 
			
		||||
    student_ner = student.add_pipe("ner")
 | 
			
		||||
    student_ner.cfg["update_with_oracle_cut_size"] = max_moves
 | 
			
		||||
    student_ner.initialize(
 | 
			
		||||
        get_examples=lambda: train_examples, labels=teacher_ner.label_data
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -402,7 +402,9 @@ def test_is_distillable():
 | 
			
		|||
    assert parser.is_distillable
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_distill():
 | 
			
		||||
@pytest.mark.slow
 | 
			
		||||
@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
 | 
			
		||||
def test_distill(max_moves):
 | 
			
		||||
    teacher = English()
 | 
			
		||||
    teacher_parser = teacher.add_pipe("parser")
 | 
			
		||||
    train_examples = []
 | 
			
		||||
| 
						 | 
				
			
			@ -420,6 +422,7 @@ def test_distill():
 | 
			
		|||
 | 
			
		||||
    student = English()
 | 
			
		||||
    student_parser = student.add_pipe("parser")
 | 
			
		||||
    student_parser.cfg["update_with_oracle_cut_size"] = max_moves
 | 
			
		||||
    student_parser.initialize(
 | 
			
		||||
        get_examples=lambda: train_examples, labels=teacher_parser.label_data
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user