From d54cc5245a5c59e057d5ae61482b4378491eada5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 13 Jan 2023 10:32:14 +0100 Subject: [PATCH] Add `validate_distillation_examples` This first calls `validate_examples` and then checks that the student/teacher tokens are the same. --- spacy/errors.py | 2 ++ spacy/pipeline/trainable_pipe.pyx | 4 ++-- spacy/pipeline/transition_parser.pyx | 3 ++- spacy/tests/training/test_training.py | 15 ++++++++++++++- spacy/training/__init__.py | 1 + spacy/training/example.pyx | 7 +++++++ 6 files changed, 28 insertions(+), 4 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index 1a1e9ea10..91e7a6006 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -956,6 +956,8 @@ class Errors(metaclass=ErrorsWithCodes): E4001 = ("Expected input to be one of the following types: ({expected_types}), " "but got '{received_type}'") E4002 = ("Pipe '{name}' requires a teacher pipe for distillation.") + E4003 = ("Training examples for distillation must have the exact same tokens in the " + "reference and predicted docs.") # fmt: on diff --git a/spacy/pipeline/trainable_pipe.pyx b/spacy/pipeline/trainable_pipe.pyx index fbcbe7a17..e3944592c 100644 --- a/spacy/pipeline/trainable_pipe.pyx +++ b/spacy/pipeline/trainable_pipe.pyx @@ -6,7 +6,7 @@ import warnings from ..tokens.doc cimport Doc -from ..training import validate_examples +from ..training import validate_examples, validate_distillation_examples from ..errors import Errors, Warnings from .pipe import Pipe, deserialize_config from .. import util @@ -88,7 +88,7 @@ cdef class TrainablePipe(Pipe): if losses is None: losses = {} losses.setdefault(self.name, 0.0) - validate_examples(examples, "TrainablePipe.distill") + validate_distillation_examples(examples, "TrainablePipe.distill") set_dropout_rate(self.model, drop) for node in teacher_pipe.model.walk(): if node.name == "softmax": diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index 44decbf5f..7ba2641ae 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -30,6 +30,7 @@ from ._parser_internals cimport _beam_utils from ._parser_internals import _beam_utils from ..training import validate_examples, validate_get_examples +from ..training import validate_distillation_examples from ..errors import Errors, Warnings from .. import util @@ -236,7 +237,7 @@ cdef class Parser(TrainablePipe): losses = {} losses.setdefault(self.name, 0.0) - validate_examples(examples, "TransitionParser.distill") + validate_distillation_examples(examples, "TransitionParser.distill") set_dropout_rate(self.model, drop) diff --git a/spacy/tests/training/test_training.py b/spacy/tests/training/test_training.py index 7933ea31f..9fdd416b1 100644 --- a/spacy/tests/training/test_training.py +++ b/spacy/tests/training/test_training.py @@ -8,7 +8,7 @@ from spacy.lang.en import English from spacy.tokens import Doc, DocBin from spacy.training import Alignment, Corpus, Example, biluo_tags_to_offsets from spacy.training import biluo_tags_to_spans, docs_to_json, iob_to_biluo -from spacy.training import offsets_to_biluo_tags +from spacy.training import offsets_to_biluo_tags, validate_distillation_examples from spacy.training.alignment_array import AlignmentArray from spacy.training.align import get_alignments from spacy.training.converters import json_to_docs @@ -365,6 +365,19 @@ def test_example_from_dict_some_ner(en_vocab): assert ner_tags == ["U-LOC", None, None, None] +def test_validate_distillation_examples(en_vocab): + words = ["a", "b", "c", "d"] + spaces = [True, True, False, True] + predicted = Doc(en_vocab, words=words, spaces=spaces) + + example = Example.from_dict(predicted, {}) + validate_distillation_examples([example], "test_validate_distillation_examples") + + example = Example.from_dict(predicted, {"words": words + ["e"]}) + with pytest.raises(ValueError, match=r"distillation"): + validate_distillation_examples([example], "test_validate_distillation_examples") + + @pytest.mark.filterwarnings("ignore::UserWarning") def test_json_to_docs_no_ner(en_vocab): data = [ diff --git a/spacy/training/__init__.py b/spacy/training/__init__.py index 71d1fa775..454437104 100644 --- a/spacy/training/__init__.py +++ b/spacy/training/__init__.py @@ -1,5 +1,6 @@ from .corpus import Corpus, JsonlCorpus # noqa: F401 from .example import Example, validate_examples, validate_get_examples # noqa: F401 +from .example import validate_distillation_examples # noqa: F401 from .alignment import Alignment # noqa: F401 from .augment import dont_augment, orth_variants_augmenter # noqa: F401 from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401 diff --git a/spacy/training/example.pyx b/spacy/training/example.pyx index 95b0f0de9..d6f3a07fb 100644 --- a/spacy/training/example.pyx +++ b/spacy/training/example.pyx @@ -47,6 +47,13 @@ def validate_examples(examples, method): raise TypeError(err) +def validate_distillation_examples(examples, method): + validate_examples(examples, method) + for eg in examples: + if [token.text for token in eg.reference] != [token.text for token in eg.predicted]: + raise ValueError(Errors.E4003) + + def validate_get_examples(get_examples, method): """Check that a generator of a batch of examples received during processing is valid: the callable produces a non-empty list of Example objects.