Mirror of https://github.com/explosion/spaCy.git (synced 2025-11-04 09:57:26 +03:00)

Commit: b56434c73b ("Update test")
Parent: 288d88a472
@@ -542,8 +542,48 @@ def test_tok2vec_listeners_textcat():
     assert [t.tag_ for t in docs[1]] == ["N", "V", "J", "N"]
 
 
-def test_tok2vec_distill():
-    orig_config = Config().from_str(cfg_string_multi_textcat)
+cfg_string_distillation = """
+    [nlp]
+    lang = "en"
+    pipeline = ["tok2vec","tagger"]
+
+    [components]
+
+    [components.tagger]
+    factory = "tagger"
+
+    [components.tagger.model]
+    @architectures = "spacy.Tagger.v2"
+    nO = null
+
+    [components.tagger.model.tok2vec]
+    @architectures = "spacy.Tok2VecListener.v1"
+    width = ${components.tok2vec.model.encode.width}
+
+    [components.tok2vec]
+    factory = "tok2vec"
+
+    [components.tok2vec.model]
+    @architectures = "spacy.Tok2Vec.v2"
+
+    [components.tok2vec.model.embed]
+    @architectures = "spacy.MultiHashEmbed.v2"
+    width = ${components.tok2vec.model.encode.width}
+    rows = [2000, 1000, 1000, 1000]
+    attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
+    include_static_vectors = false
+
+    [components.tok2vec.model.encode]
+    @architectures = "spacy.MaxoutWindowEncoder.v2"
+    width = 96
+    depth = 4
+    window_size = 1
+    maxout_pieces = 3
+    """
+
+
+def test_tok2vec_distillation_teacher_annotations():
+    orig_config = Config().from_str(cfg_string_distillation)
     teacher_nlp = util.load_model_from_config(
         orig_config, auto_fill=True, validate=True
     )
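The config added above shares a single tok2vec between components: the tagger's model consumes a "spacy.Tok2VecListener.v1" whose width is interpolated from the shared encoder's width (96). A minimal sketch of how that wiring can be checked, assuming the Config/util imports the test file already uses and that the listener exposes its width as the "nO" dim:

    # Sketch only: build the pipeline from cfg_string_distillation (the
    # string added above) and confirm the tagger's tok2vec ref is a
    # listener whose output width was interpolated from the encoder (96).
    from thinc.api import Config
    from spacy import util

    config = Config().from_str(cfg_string_distillation)
    nlp = util.load_model_from_config(config, auto_fill=True, validate=True)
    listener = nlp.get_pipe("tagger").model.get_ref("tok2vec")
    assert listener.get_dim("nO") == 96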
@@ -551,10 +591,6 @@ def test_tok2vec_distill():
         orig_config, auto_fill=True, validate=True
     )
 
-    # Remove pipes that don't currently support distillation.
-    teacher_nlp.remove_pipe("textcat_multilabel")
-    student_nlp.remove_pipe("textcat_multilabel")
-
     train_examples_teacher = []
     train_examples_student = []
     for t in TRAIN_DATA:
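The context around this hunk initializes empty example lists and iterates over TRAIN_DATA. A sketch of how such lists are typically filled in this test file, assuming TRAIN_DATA pairs a text with an annotation dict (e.g. {"tags": [...]}):

    # Sketch only: one Example per training pair, built against each
    # pipeline's own tokenization so teacher and student get separate docs.
    from spacy.training import Example

    for t in TRAIN_DATA:
        train_examples_teacher.append(
            Example.from_dict(teacher_nlp.make_doc(t[0]), t[1])
        )
        train_examples_student.append(
            Example.from_dict(student_nlp.make_doc(t[0]), t[1])
        )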
@@ -571,39 +607,25 @@ def test_tok2vec_distill():
         teacher_nlp.update(train_examples_teacher, sgd=optimizer, losses=losses)
 
     student_nlp.initialize(lambda: train_examples_student)
-    student_tagger = student_nlp.get_pipe("tagger")
 
-    tagger_tok2vec = student_tagger.model.get_ref("tok2vec")
-    tagger_tok2vec_forward = tagger_tok2vec._func
-
-    def mock_listener_forward(model: Tok2VecListener, inputs, is_train: bool):
-        model.attrs["last_input"] = inputs
-        return tagger_tok2vec_forward(model, inputs, is_train)
-
-    tagger_tok2vec._func = mock_listener_forward
-
-    # Since Language.distill creates a copy of the student docs to use as
-    # its internal teacher docs, we'll need to monkey-patch the tok2vec pipe's
-    # distill method.
+    # Since Language.distill creates a copy of the examples to use as
+    # its internal teacher/student docs, we'll need to monkey-patch the
+    # tok2vec pipe's distill method.
     student_tok2vec = student_nlp.get_pipe("tok2vec")
     student_tok2vec._old_distill = student_tok2vec.distill
 
     def tok2vec_distill_wrapper(
         self,
         teacher_pipe,
-        teacher_docs,
-        student_docs,
+        examples,
         **kwargs,
     ):
-        assert all(not doc.tensor.any() for doc in teacher_docs)
-        out = self._old_distill(teacher_pipe, teacher_docs, student_docs, **kwargs)
-        assert all(doc.tensor.any() for doc in teacher_docs)
+        assert all(not eg.reference.tensor.any() for eg in examples)
+        out = self._old_distill(teacher_pipe, examples, **kwargs)
+        assert all(eg.reference.tensor.any() for eg in examples)
        return out
 
     student_tok2vec.distill = tok2vec_distill_wrapper.__get__(student_tok2vec, Tok2Vec)
-
-    student_docs = [eg.predicted for eg in train_examples_student]
     student_nlp.distill(
-        teacher_nlp, student_docs, sgd=optimizer, losses=losses, pipe_map={}
+        teacher_nlp, train_examples_student, sgd=optimizer, losses=losses
     )
-    assert tagger_tok2vec.attrs["last_input"] == student_docs
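The wrapper above is attached with tok2vec_distill_wrapper.__get__(student_tok2vec, Tok2Vec), which uses Python's descriptor protocol to bind a plain function to an instance. A self-contained sketch of the same monkey-patching pattern, independent of spaCy (all names here are illustrative):

    # Sketch only: function.__get__(obj, cls) returns a bound method, so
    # the wrapper receives `self` exactly like the method it replaces.
    class Pipe:
        def distill(self, x):
            return x * 2

    def distill_wrapper(self, x):
        # Pre/post conditions can be asserted around the original call,
        # which is reached via the saved bound method.
        out = self._old_distill(x)
        assert out == x * 2
        return out

    pipe = Pipe()
    pipe._old_distill = pipe.distill              # keep the original bound method
    pipe.distill = distill_wrapper.__get__(pipe, Pipe)  # bind wrapper to instance
    assert pipe.distill(3) == 6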