Add Pipe.is_distillable method

This commit is contained in:
Daniël de Kok 2023-01-13 09:48:11 +01:00
parent dd83157594
commit 1fb095a79b
9 changed files with 50 additions and 0 deletions

View File

@ -87,6 +87,10 @@ cdef class Pipe:
return self.scorer(examples, **scorer_kwargs) return self.scorer(examples, **scorer_kwargs)
return {} return {}
@property
def is_distillable(self) -> bool:
    """Tell callers whether this pipe can be distilled.

    This implementation unconditionally answers False.
    """
    return False
@property
def is_trainable(self) -> bool:
    """Tell callers whether this pipe can be trained.

    This implementation unconditionally answers False.
    """
    return False

View File

@ -264,6 +264,10 @@ cdef class TrainablePipe(Pipe):
""" """
raise NotImplementedError(Errors.E931.format(parent="Pipe", method="add_label", name=self.name)) raise NotImplementedError(Errors.E931.format(parent="Pipe", method="add_label", name=self.name))
@property
def is_distillable(self) -> bool:
    """Report whether this pipe's class implements distillation.

    The pipe counts as distillable when the `distill` or the
    `get_teacher_student_loss` attribute of its class is not the very
    same object as the one found on `TrainablePipe`. If both are
    identical to the `TrainablePipe` attributes, the result is False.
    """
    cls = self.__class__
    # Identity (not equality) comparison; the second check is only
    # evaluated when `distill` is the one found on TrainablePipe.
    return (
        cls.distill is not TrainablePipe.distill
        or cls.get_teacher_student_loss is not TrainablePipe.get_teacher_student_loss
    )
@property
def is_trainable(self) -> bool:
    """Tell callers whether this pipe can be trained.

    This implementation unconditionally answers True.
    """
    return True

View File

@ -617,6 +617,12 @@ def test_overfitting_IO(use_upper):
assert ents[1].kb_id == 0 assert ents[1].kb_id == 0
def test_is_distillable():
    """The "ner" component added to an English pipeline reports is_distillable."""
    pipeline = English()
    component = pipeline.add_pipe("ner")
    assert component.is_distillable
def test_distill(): def test_distill():
teacher = English() teacher = English()
teacher_ner = teacher.add_pipe("ner") teacher_ner = teacher.add_pipe("ner")

View File

@ -396,6 +396,12 @@ def test_overfitting_IO(pipe_name):
assert_equal(batch_deps_1, no_batch_deps) assert_equal(batch_deps_1, no_batch_deps)
def test_is_distillable():
    """The "parser" component added to an English pipeline reports is_distillable."""
    pipeline = English()
    component = pipeline.add_pipe("parser")
    assert component.is_distillable
def test_distill(): def test_distill():
teacher = English() teacher = English()
teacher_parser = teacher.add_pipe("parser") teacher_parser = teacher.add_pipe("parser")

View File

@ -195,6 +195,12 @@ def test_overfitting_IO():
assert doc4[3].lemma_ == "egg" assert doc4[3].lemma_ == "egg"
def test_is_distillable():
    """The "trainable_lemmatizer" component added to an English pipeline reports is_distillable."""
    pipeline = English()
    component = pipeline.add_pipe("trainable_lemmatizer")
    assert component.is_distillable
def test_distill(): def test_distill():
teacher = English() teacher = English()
teacher_lemmatizer = teacher.add_pipe("trainable_lemmatizer") teacher_lemmatizer = teacher.add_pipe("trainable_lemmatizer")

View File

@ -50,6 +50,12 @@ def test_implicit_label():
nlp.initialize(get_examples=lambda: train_examples) nlp.initialize(get_examples=lambda: train_examples)
def test_is_distillable():
    """The "morphologizer" component added to an English pipeline reports is_distillable."""
    pipeline = English()
    component = pipeline.add_pipe("morphologizer")
    assert component.is_distillable
def test_no_resize(): def test_no_resize():
nlp = Language() nlp = Language()
morphologizer = nlp.add_pipe("morphologizer") morphologizer = nlp.add_pipe("morphologizer")

View File

@ -11,6 +11,12 @@ from spacy.pipeline import TrainablePipe
from spacy.tests.util import make_tempdir from spacy.tests.util import make_tempdir
def test_is_distillable():
    """The "senter" component added to an English pipeline reports is_distillable."""
    pipeline = English()
    component = pipeline.add_pipe("senter")
    assert component.is_distillable
def test_label_types(): def test_label_types():
nlp = Language() nlp = Language()
senter = nlp.add_pipe("senter") senter = nlp.add_pipe("senter")

View File

@ -213,6 +213,12 @@ def test_overfitting_IO():
assert doc3[0].tag_ != "N" assert doc3[0].tag_ != "N"
def test_is_distillable():
    """The "tagger" component added to an English pipeline reports is_distillable."""
    pipeline = English()
    component = pipeline.add_pipe("tagger")
    assert component.is_distillable
def test_distill(): def test_distill():
teacher = English() teacher = English()
teacher_tagger = teacher.add_pipe("tagger") teacher_tagger = teacher.add_pipe("tagger")

View File

@ -565,6 +565,12 @@ def test_initialize_examples(name, get_examples, train_data):
nlp.initialize(get_examples=get_examples()) nlp.initialize(get_examples=get_examples())
def test_is_distillable():
    """The "textcat" component added to an English pipeline does not report is_distillable."""
    pipeline = English()
    component = pipeline.add_pipe("textcat")
    assert not component.is_distillable
def test_overfitting_IO(): def test_overfitting_IO():
# Simple test to try and quickly overfit the single-label textcat component - ensuring the ML models work correctly # Simple test to try and quickly overfit the single-label textcat component - ensuring the ML models work correctly
fix_random_seed(0) fix_random_seed(0)