mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-25 13:11:03 +03:00 
			
		
		
		
	* Add `TrainablePipe.{distill,get_teacher_student_loss}`
This change adds two methods:
- `TrainablePipe::distill`, which performs a training step of a
   student pipe on a teacher pipe, given a batch of `Doc`s.
- `TrainablePipe::get_teacher_student_loss` computes the loss
  of a student relative to the teacher.
The `distill` or `get_teacher_student_loss` methods are also implemented
in the tagger, edit tree lemmatizer, and parser pipes, to enable
distillation in those pipes and as an example for other pipes.
* Fix stray `Beam` import
* Fix incorrect import
* Apply suggestions from code review
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
* Apply suggestions from code review
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
* TrainablePipe.distill: use `Iterable[Example]`
* Add Pipe.is_distillable method
* Add `validate_distillation_examples`
This first calls `validate_examples` and then checks that the
student/teacher tokens are the same.
* Update distill documentation
* Add distill documentation for all pipes that support distillation
* Fix incorrect identifier
* Apply suggestions from code review
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
* Add comment to explain `is_distillable`
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
		
	
			
		
			
				
	
	
		
			388 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			388 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from typing import cast
 | |
| import pickle
 | |
| import pytest
 | |
| from hypothesis import given
 | |
| import hypothesis.strategies as st
 | |
| from spacy import util
 | |
| from spacy.lang.en import English
 | |
| from spacy.language import Language
 | |
| from spacy.pipeline._edit_tree_internals.edit_trees import EditTrees
 | |
| from spacy.pipeline.trainable_pipe import TrainablePipe
 | |
| from spacy.training import Example
 | |
| from spacy.strings import StringStore
 | |
| from spacy.util import make_tempdir
 | |
| 
 | |
| 
 | |
# Fully annotated sentences: each text is paired with one lemma per token.
TRAIN_DATA = [
    (
        "She likes green eggs",
        {"lemmas": ["she", "like", "green", "egg"]},
    ),
    (
        "Eat blue ham",
        {"lemmas": ["eat", "blue", "ham"]},
    ),
]
 | |
| 
 | |
# Sentences whose lemma annotation is incomplete: an empty string stands in
# for a token without a gold lemma.
PARTIAL_DATA = [
    # Partial annotation, one lemma slot per token of the text.
    (
        "She likes green eggs",
        {"lemmas": ["", "like", "green", ""]},
    ),
    # Partial annotation whose "words" split the text differently
    # ("hates" is given as "hat" + "es"), i.e. misaligned with the text.
    (
        "He hates green eggs",
        {
            "words": ["He", "hat", "es", "green", "eggs"],
            "lemmas": ["", "hat", "e", "green", ""],
        },
    ),
]
 | |
| 
 | |
| 
 | |
def test_initialize_examples():
    """Check that `nlp.initialize` raises `TypeError` for bad `get_examples`."""
    nlp = Language()
    nlp.add_pipe("trainable_lemmatizer")
    train_examples = [
        Example.from_dict(nlp.make_doc(text), annots) for text, annots in TRAIN_DATA
    ]

    # Initializing more than once is not typical usage, but fine for testing.
    nlp.initialize(get_examples=lambda: train_examples)

    invalid_inputs = (
        lambda: None,  # callable that yields nothing usable
        lambda: train_examples[0],  # callable returning a single Example
        lambda: [],  # callable returning an empty list
        train_examples,  # a plain list instead of a callable
    )
    for get_examples in invalid_inputs:
        with pytest.raises(TypeError):
            nlp.initialize(get_examples=get_examples)
 | |
| 
 | |
| 
 | |
def test_initialize_from_labels():
    """Check the tree/label tables of a lemmatizer initialized from `label_data`."""
    # 2**32 - 1; appears wherever a node has no prefix/suffix subtree.
    # NOTE(review): presumably the "null tree" sentinel of EditTrees -- confirm.
    no_tree = 4294967295

    def interior(prefix_len, suffix_len, prefix_tree, suffix_tree):
        # Expected serialized form of a node that refers to two subtrees.
        return {
            "prefix_len": prefix_len,
            "suffix_len": suffix_len,
            "prefix_tree": prefix_tree,
            "suffix_tree": suffix_tree,
        }

    def replacement(orig, subst):
        # Expected serialized form of a string replacement node.
        return {"orig": orig, "subst": subst}

    nlp = Language()
    lemmatizer = nlp.add_pipe("trainable_lemmatizer")
    lemmatizer.min_tree_freq = 1
    train_examples = [
        Example.from_dict(nlp.make_doc(text), annots) for text, annots in TRAIN_DATA
    ]
    nlp.initialize(get_examples=lambda: train_examples)

    nlp2 = Language()
    lemmatizer2 = nlp2.add_pipe("trainable_lemmatizer")
    # We want to check that the strings in replacement nodes are added to
    # the string store, so only the first example is passed here: the
    # remaining strings must not get added through the examples.
    lemmatizer2.initialize(
        get_examples=lambda: train_examples[:1],
        labels=lemmatizer.label_data,
    )

    expected_trees = [
        replacement("S", "s"),
        interior(1, 0, 0, no_tree),
        replacement("s", ""),
        interior(0, 1, no_tree, 2),
        interior(0, 0, no_tree, no_tree),
        replacement("E", "e"),
        interior(1, 0, 5, no_tree),
    ]
    assert lemmatizer2.tree2label == {1: 0, 3: 1, 4: 2, 6: 3}
    assert lemmatizer2.label_data == {
        "trees": expected_trees,
        "labels": (1, 3, 4, 6),
    }
 | |
| 
 | |
| 
 | |
def test_no_data():
    """Initializing with examples that only carry `cats` raises `ValueError`."""
    # The lemmatizer should provide a nice error when there's no tagging
    # data / labels: these examples are annotated for textcat only.
    textcat_data = [
        ("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
        ("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}),
    ]
    nlp = English()
    for pipe_name in ("trainable_lemmatizer", "textcat"):
        nlp.add_pipe(pipe_name)

    train_examples = [
        Example.from_dict(nlp.make_doc(text), annots) for text, annots in textcat_data
    ]

    with pytest.raises(ValueError):
        nlp.initialize(get_examples=lambda: train_examples)
 | |
| 
 | |
| 
 | |
def test_incomplete_data():
    """Train on partially annotated examples and check the predicted lemmas."""
    nlp = English()
    lemmatizer = nlp.add_pipe("trainable_lemmatizer")
    lemmatizer.min_tree_freq = 1
    train_examples = [
        Example.from_dict(nlp.make_doc(text), annots) for text, annots in PARTIAL_DATA
    ]
    optimizer = nlp.initialize(get_examples=lambda: train_examples)

    # Only the losses of the final update are inspected, so the dict is
    # recreated on every step.
    for _ in range(50):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)
    assert losses["trainable_lemmatizer"] < 0.00001

    # Test the trained model.
    doc = nlp("She likes blue eggs")
    predicted = {index: doc[index].lemma_ for index in (1, 2)}
    assert predicted[1] == "like"
    assert predicted[2] == "blue"
 | |
| 
 | |
| 
 | |
def test_overfitting_IO():
    """Overfit the lemmatizer, then check predictions across serialization roundtrips."""
    test_text = "She likes blue eggs"
    expected_lemmas = ["she", "like", "blue", "egg"]

    def assert_lemmas(doc):
        # Compare the first four tokens one by one against the expected lemmas.
        for index, lemma in enumerate(expected_lemmas):
            assert doc[index].lemma_ == lemma

    nlp = English()
    lemmatizer = nlp.add_pipe("trainable_lemmatizer")
    lemmatizer.min_tree_freq = 1
    train_examples = [
        Example.from_dict(nlp.make_doc(text), annots) for text, annots in TRAIN_DATA
    ]

    optimizer = nlp.initialize(get_examples=lambda: train_examples)

    # Only the losses of the final update are inspected.
    for _ in range(50):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)
    assert losses["trainable_lemmatizer"] < 0.00001

    assert_lemmas(nlp(test_text))

    # Check model after a {to,from}_disk roundtrip
    with util.make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = util.load_model_from_path(tmp_dir)
        assert_lemmas(nlp2(test_text))

    # Check model after a {to,from}_bytes roundtrip
    nlp_bytes = nlp.to_bytes()
    nlp3 = English()
    nlp3.add_pipe("trainable_lemmatizer")
    nlp3.from_bytes(nlp_bytes)
    assert_lemmas(nlp3(test_text))

    # Check model after a pickle roundtrip.
    pickled = pickle.dumps(nlp)
    nlp4 = pickle.loads(pickled)
    assert_lemmas(nlp4(test_text))
 | |
| 
 | |
| 
 | |
def test_is_distillable():
    """The trainable lemmatizer's `is_distillable` must be truthy."""
    pipeline = English()
    pipe = pipeline.add_pipe("trainable_lemmatizer")
    assert pipe.is_distillable
 | |
| 
 | |
| 
 | |
def test_distill():
    """Overfit a teacher lemmatizer, distill it into a student, check the student."""
    teacher = English()
    teacher_lemmatizer = teacher.add_pipe("trainable_lemmatizer")
    teacher_lemmatizer.min_tree_freq = 1
    train_examples = [
        Example.from_dict(teacher.make_doc(text), annots)
        for text, annots in TRAIN_DATA
    ]

    optimizer = teacher.initialize(get_examples=lambda: train_examples)

    # Only the losses of the final update are inspected.
    for _ in range(50):
        losses = {}
        teacher.update(train_examples, sgd=optimizer, losses=losses)
    assert losses["trainable_lemmatizer"] < 0.00001

    student = English()
    student_lemmatizer = student.add_pipe("trainable_lemmatizer")
    student_lemmatizer.min_tree_freq = 1
    # The student takes its label set from the teacher's label data.
    student_lemmatizer.initialize(
        get_examples=lambda: train_examples,
        labels=teacher_lemmatizer.label_data,
    )

    # The distillation examples carry the texts only, without annotations.
    distill_examples = []
    for text, _ in TRAIN_DATA:
        distill_examples.append(Example.from_dict(teacher.make_doc(text), {}))

    # The optimizer returned by the teacher's initialization is reused here.
    for _ in range(50):
        losses = {}
        student_lemmatizer.distill(
            teacher_lemmatizer, distill_examples, sgd=optimizer, losses=losses
        )
    assert losses["trainable_lemmatizer"] < 0.00001

    doc = student("She likes blue eggs")
    for index, lemma in enumerate(["she", "like", "blue", "egg"]):
        assert doc[index].lemma_ == lemma
 | |
| 
 | |
| 
 | |
def test_lemmatizer_requires_labels():
    """Initializing without examples or labels raises `ValueError`."""
    pipeline = English()
    pipeline.add_pipe("trainable_lemmatizer")
    with pytest.raises(ValueError):
        pipeline.initialize()
 | |
| 
 | |
| 
 | |
def test_lemmatizer_label_data():
    """A pipe initialized from `label_data` has the same labels and trees as its source."""
    nlp = English()
    lemmatizer = nlp.add_pipe("trainable_lemmatizer")
    lemmatizer.min_tree_freq = 1
    train_examples = [
        Example.from_dict(nlp.make_doc(text), annots) for text, annots in TRAIN_DATA
    ]

    nlp.initialize(get_examples=lambda: train_examples)

    nlp2 = English()
    lemmatizer2 = nlp2.add_pipe("trainable_lemmatizer")
    lemmatizer2.initialize(
        get_examples=lambda: train_examples,
        labels=lemmatizer.label_data,
    )

    # Verify that the labels and trees are the same.
    assert lemmatizer.labels == lemmatizer2.labels
    source_trees = lemmatizer.trees.to_bytes()
    copied_trees = lemmatizer2.trees.to_bytes()
    assert source_trees == copied_trees
 | |
| 
 | |
| 
 | |
def test_dutch():
    """Check the string form of the edit trees built for two Dutch verb forms."""
    cases = [
        ("deelt", "delen", "(m 0 3 () (m 0 2 (s '' 'l') (s 'lt' 'n')))"),
        (
            "gedeeld",
            "delen",
            "(m 2 3 (s 'ge' '') (m 0 2 (s '' 'l') (s 'ld' 'n')))",
        ),
    ]
    trees = EditTrees(StringStore())
    # Both forms are added to the same EditTrees instance, in this order.
    for form, lemma, expected in cases:
        tree_id = trees.add(form, lemma)
        assert trees.tree_to_str(tree_id) == expected
 | |
| 
 | |
| 
 | |
def test_from_to_bytes():
    """`EditTrees` keeps its nodes across a `to_bytes`/`from_bytes` roundtrip."""
    pairs = [("deelt", "delen"), ("gedeeld", "delen")]
    strings = StringStore()
    original = EditTrees(strings)
    for form, lemma in pairs:
        original.add(form, lemma)

    serialized = original.to_bytes()

    # The restored instance shares the string store with the original.
    restored = EditTrees(strings)
    restored.from_bytes(serialized)

    # Verify that the nodes did not change.
    assert len(original) == len(restored)
    for tree_id in range(len(original)):
        assert original.tree_to_str(tree_id) == restored.tree_to_str(tree_id)

    # Reinserting the same trees should not add new nodes.
    for form, lemma in pairs:
        restored.add(form, lemma)
    assert len(original) == len(restored)
 | |
| 
 | |
| 
 | |
def test_from_to_disk():
    """`EditTrees` keeps its nodes across a `to_disk`/`from_disk` roundtrip."""
    pairs = [("deelt", "delen"), ("gedeeld", "delen")]
    strings = StringStore()
    original = EditTrees(strings)
    for form, lemma in pairs:
        original.add(form, lemma)

    # The restored instance shares the string store with the original.
    restored = EditTrees(strings)
    with make_tempdir() as temp_dir:
        target = temp_dir / "edit_trees.bin"
        original.to_disk(target)
        # The object returned by from_disk is the one checked below.
        restored = restored.from_disk(target)

    # Verify that the nodes did not change.
    assert len(original) == len(restored)
    for tree_id in range(len(original)):
        assert original.tree_to_str(tree_id) == restored.tree_to_str(tree_id)

    # Reinserting the same trees should not add new nodes.
    for form, lemma in pairs:
        restored.add(form, lemma)
    assert len(original) == len(restored)
 | |
| 
 | |
| 
 | |
@given(st.text(), st.text())
def test_roundtrip(form, lemma):
    """Applying the tree built for (form, lemma) to `form` gives back `lemma`."""
    edit_trees = EditTrees(StringStore())
    tree_id = edit_trees.add(form, lemma)
    result = edit_trees.apply(tree_id, form)
    assert result == lemma
 | |
| 
 | |
| 
 | |
@given(st.text(alphabet="ab"), st.text(alphabet="ab"))
def test_roundtrip_small_alphabet(form, lemma):
    """Roundtrip check on strings drawn from the two-letter alphabet "ab"."""
    # Test with small alphabets to have more overlap between form and lemma.
    edit_trees = EditTrees(StringStore())
    tree_id = edit_trees.add(form, lemma)
    result = edit_trees.apply(tree_id, form)
    assert result == lemma
 | |
| 
 | |
| 
 | |
def test_unapplicable_trees():
    """`EditTrees.apply` returns `None` when a tree cannot be applied to a form."""
    strings = StringStore()
    trees = EditTrees(strings)
    tree3 = trees.add("deelt", "delen")

    # `None` is a singleton, so the checks use identity (`is None`) rather
    # than `== None`, which an overridden `__eq__` could satisfy.

    # Replacement fails.
    assert trees.apply(tree3, "deeld") is None

    # Suffix + prefix are too large.
    assert trees.apply(tree3, "de") is None
 | |
| 
 | |
| 
 | |
def test_empty_strings():
    """An identity pair and an empty pair yield the same value from `EditTrees.add`."""
    edit_trees = EditTrees(StringStore())
    unchanged_id = edit_trees.add("xyz", "xyz")
    empty_id = edit_trees.add("", "")
    assert unchanged_id == empty_id
 | |
| 
 | |
| 
 | |
def test_save_activations():
    """Activations are stored on the doc only once `save_activations` is enabled."""
    pipe_name = "trainable_lemmatizer"
    text = "This is a test."

    nlp = English()
    lemmatizer = cast(TrainablePipe, nlp.add_pipe(pipe_name))
    lemmatizer.min_tree_freq = 1
    train_examples = [
        Example.from_dict(nlp.make_doc(t), annots) for t, annots in TRAIN_DATA
    ]
    nlp.initialize(get_examples=lambda: train_examples)
    nO = lemmatizer.model.get_dim("nO")

    # Before save_activations is switched on, nothing is recorded for the pipe.
    doc = nlp(text)
    assert pipe_name not in doc.activations

    lemmatizer.save_activations = True
    doc = nlp(text)
    activations = doc.activations[pipe_name]
    assert list(activations.keys()) == [
        "probabilities",
        "tree_ids",
    ]
    # The expected shapes use 5 rows for this text.
    assert activations["probabilities"].shape == (5, nO)
    assert activations["tree_ids"].shape == (5,)
 |