mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-03 20:00:21 +03:00
Add validate_distillation_examples
This first calls `validate_examples` and then checks that the student/teacher tokens are the same.
This commit is contained in:
parent
1fb095a79b
commit
d54cc5245a
|
@ -956,6 +956,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
E4001 = ("Expected input to be one of the following types: ({expected_types}), "
|
||||
"but got '{received_type}'")
|
||||
E4002 = ("Pipe '{name}' requires a teacher pipe for distillation.")
|
||||
E4003 = ("Training examples for distillation must have the exact same tokens in the "
|
||||
"reference and predicted docs.")
|
||||
|
||||
|
||||
# fmt: on
|
||||
|
|
|
@ -6,7 +6,7 @@ import warnings
|
|||
|
||||
from ..tokens.doc cimport Doc
|
||||
|
||||
from ..training import validate_examples
|
||||
from ..training import validate_examples, validate_distillation_examples
|
||||
from ..errors import Errors, Warnings
|
||||
from .pipe import Pipe, deserialize_config
|
||||
from .. import util
|
||||
|
@ -88,7 +88,7 @@ cdef class TrainablePipe(Pipe):
|
|||
if losses is None:
|
||||
losses = {}
|
||||
losses.setdefault(self.name, 0.0)
|
||||
validate_examples(examples, "TrainablePipe.distill")
|
||||
validate_distillation_examples(examples, "TrainablePipe.distill")
|
||||
set_dropout_rate(self.model, drop)
|
||||
for node in teacher_pipe.model.walk():
|
||||
if node.name == "softmax":
|
||||
|
|
|
@ -30,6 +30,7 @@ from ._parser_internals cimport _beam_utils
|
|||
from ._parser_internals import _beam_utils
|
||||
|
||||
from ..training import validate_examples, validate_get_examples
|
||||
from ..training import validate_distillation_examples
|
||||
from ..errors import Errors, Warnings
|
||||
from .. import util
|
||||
|
||||
|
@ -236,7 +237,7 @@ cdef class Parser(TrainablePipe):
|
|||
losses = {}
|
||||
losses.setdefault(self.name, 0.0)
|
||||
|
||||
validate_examples(examples, "TransitionParser.distill")
|
||||
validate_distillation_examples(examples, "TransitionParser.distill")
|
||||
|
||||
set_dropout_rate(self.model, drop)
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from spacy.lang.en import English
|
|||
from spacy.tokens import Doc, DocBin
|
||||
from spacy.training import Alignment, Corpus, Example, biluo_tags_to_offsets
|
||||
from spacy.training import biluo_tags_to_spans, docs_to_json, iob_to_biluo
|
||||
from spacy.training import offsets_to_biluo_tags
|
||||
from spacy.training import offsets_to_biluo_tags, validate_distillation_examples
|
||||
from spacy.training.alignment_array import AlignmentArray
|
||||
from spacy.training.align import get_alignments
|
||||
from spacy.training.converters import json_to_docs
|
||||
|
@ -365,6 +365,19 @@ def test_example_from_dict_some_ner(en_vocab):
|
|||
assert ner_tags == ["U-LOC", None, None, None]
|
||||
|
||||
|
||||
def test_validate_distillation_examples(en_vocab):
    """Check that distillation validation passes for an example built from an
    empty annotation dict and raises when the reference has an extra word."""
    method = "test_validate_distillation_examples"
    words = ["a", "b", "c", "d"]
    spaces = [True, True, False, True]
    predicted = Doc(en_vocab, words=words, spaces=spaces)

    # No annotations supplied: validation is expected to complete silently.
    matching = Example.from_dict(predicted, {})
    validate_distillation_examples([matching], method)

    # The reference gets one more word than the predicted doc, so the
    # token-equality check must reject the example.
    mismatched = Example.from_dict(predicted, {"words": words + ["e"]})
    with pytest.raises(ValueError, match=r"distillation"):
        validate_distillation_examples([mismatched], method)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore::UserWarning")
|
||||
def test_json_to_docs_no_ner(en_vocab):
|
||||
data = [
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from .corpus import Corpus, JsonlCorpus # noqa: F401
|
||||
from .example import Example, validate_examples, validate_get_examples # noqa: F401
|
||||
from .example import validate_distillation_examples # noqa: F401
|
||||
from .alignment import Alignment # noqa: F401
|
||||
from .augment import dont_augment, orth_variants_augmenter # noqa: F401
|
||||
from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401
|
||||
|
|
|
@ -47,6 +47,13 @@ def validate_examples(examples, method):
|
|||
raise TypeError(err)
|
||||
|
||||
|
||||
def validate_distillation_examples(examples, method):
    """Check that a batch of examples received during distillation is valid:
    the batch must pass `validate_examples`, and every example must have the
    same token texts in its reference and predicted docs.

    examples: The examples to check; each must expose `reference` and
        `predicted` docs.
    method (str): Forwarded unchanged to `validate_examples`.
    RAISES (ValueError): Errors.E4003 for the first example whose reference
        and predicted token texts differ.
    """
    # Run the generic batch validation before the distillation-specific check.
    validate_examples(examples, method)
    for example in examples:
        reference_words = [token.text for token in example.reference]
        predicted_words = [token.text for token in example.predicted]
        if reference_words != predicted_words:
            raise ValueError(Errors.E4003)
|
||||
|
||||
|
||||
def validate_get_examples(get_examples, method):
|
||||
"""Check that a generator of a batch of examples received during processing is valid:
|
||||
the callable produces a non-empty list of Example objects.
|
||||
|
|
Loading…
Reference in New Issue
Block a user