Add validate_distillation_examples

This first calls `validate_examples` and then checks that the
student/teacher tokens are the same.
This commit is contained in:
Daniël de Kok 2023-01-13 10:32:14 +01:00
parent 1fb095a79b
commit d54cc5245a
6 changed files with 28 additions and 4 deletions

View File

@ -956,6 +956,8 @@ class Errors(metaclass=ErrorsWithCodes):
E4001 = ("Expected input to be one of the following types: ({expected_types}), "
"but got '{received_type}'")
E4002 = ("Pipe '{name}' requires a teacher pipe for distillation.")
E4003 = ("Training examples for distillation must have the exact same tokens in the "
"reference and predicted docs.")
# fmt: on

View File

@ -6,7 +6,7 @@ import warnings
from ..tokens.doc cimport Doc
from ..training import validate_examples
from ..training import validate_examples, validate_distillation_examples
from ..errors import Errors, Warnings
from .pipe import Pipe, deserialize_config
from .. import util
@ -88,7 +88,7 @@ cdef class TrainablePipe(Pipe):
if losses is None:
losses = {}
losses.setdefault(self.name, 0.0)
validate_examples(examples, "TrainablePipe.distill")
validate_distillation_examples(examples, "TrainablePipe.distill")
set_dropout_rate(self.model, drop)
for node in teacher_pipe.model.walk():
if node.name == "softmax":

View File

@ -30,6 +30,7 @@ from ._parser_internals cimport _beam_utils
from ._parser_internals import _beam_utils
from ..training import validate_examples, validate_get_examples
from ..training import validate_distillation_examples
from ..errors import Errors, Warnings
from .. import util
@ -236,7 +237,7 @@ cdef class Parser(TrainablePipe):
losses = {}
losses.setdefault(self.name, 0.0)
validate_examples(examples, "TransitionParser.distill")
validate_distillation_examples(examples, "TransitionParser.distill")
set_dropout_rate(self.model, drop)

View File

@ -8,7 +8,7 @@ from spacy.lang.en import English
from spacy.tokens import Doc, DocBin
from spacy.training import Alignment, Corpus, Example, biluo_tags_to_offsets
from spacy.training import biluo_tags_to_spans, docs_to_json, iob_to_biluo
from spacy.training import offsets_to_biluo_tags
from spacy.training import offsets_to_biluo_tags, validate_distillation_examples
from spacy.training.alignment_array import AlignmentArray
from spacy.training.align import get_alignments
from spacy.training.converters import json_to_docs
@ -365,6 +365,19 @@ def test_example_from_dict_some_ner(en_vocab):
assert ner_tags == ["U-LOC", None, None, None]
def test_validate_distillation_examples(en_vocab):
    """Check that validate_distillation_examples accepts an example built
    without extra annotations and raises a ValueError when the reference is
    given one more word than the predicted doc.
    """
    method_name = "test_validate_distillation_examples"
    tokens = ["a", "b", "c", "d"]
    trailing_spaces = [True, True, False, True]
    doc = Doc(en_vocab, words=tokens, spaces=trailing_spaces)

    # Empty annotation dict: validation is expected to pass silently.
    matching = Example.from_dict(doc, {})
    validate_distillation_examples([matching], method_name)

    # Reference carries an additional word, so the token texts differ.
    mismatched = Example.from_dict(doc, {"words": [*tokens, "e"]})
    with pytest.raises(ValueError, match=r"distillation"):
        validate_distillation_examples([mismatched], method_name)
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_json_to_docs_no_ner(en_vocab):
data = [

View File

@ -1,5 +1,6 @@
from .corpus import Corpus, JsonlCorpus # noqa: F401
from .example import Example, validate_examples, validate_get_examples # noqa: F401
from .example import validate_distillation_examples # noqa: F401
from .alignment import Alignment # noqa: F401
from .augment import dont_augment, orth_variants_augmenter # noqa: F401
from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401

View File

@ -47,6 +47,13 @@ def validate_examples(examples, method):
raise TypeError(err)
def validate_distillation_examples(examples, method):
    """Check that a batch of examples received during processing is valid
    for distillation: the batch passes `validate_examples` and, for every
    example, the reference and predicted docs have the same token texts.

    examples: The examples to check.
    method: Forwarded unchanged to `validate_examples` (callers pass the name
        of the calling method, e.g. "TrainablePipe.distill").
    RAISES (ValueError): With message E4003 if the token texts of the
        reference and predicted docs of an example differ.
    """
    validate_examples(examples, method)
    for example in examples:
        reference_texts = [tok.text for tok in example.reference]
        predicted_texts = [tok.text for tok in example.predicted]
        if reference_texts != predicted_texts:
            raise ValueError(Errors.E4003)
def validate_get_examples(get_examples, method):
"""Check that a generator of a batch of examples received during processing is valid:
the callable produces a non-empty list of Example objects.