mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 20:30:24 +03:00
Add Pipe.is_distillable method
This commit is contained in:
parent
dd83157594
commit
1fb095a79b
|
@ -87,6 +87,10 @@ cdef class Pipe:
|
|||
return self.scorer(examples, **scorer_kwargs)
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_distillable(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_trainable(self) -> bool:
|
||||
return False
|
||||
|
|
|
@ -264,6 +264,10 @@ cdef class TrainablePipe(Pipe):
|
|||
"""
|
||||
raise NotImplementedError(Errors.E931.format(parent="Pipe", method="add_label", name=self.name))
|
||||
|
||||
@property
|
||||
def is_distillable(self) -> bool:
|
||||
return not (self.__class__.distill is TrainablePipe.distill and self.__class__.get_teacher_student_loss is TrainablePipe.get_teacher_student_loss)
|
||||
|
||||
@property
|
||||
def is_trainable(self) -> bool:
|
||||
return True
|
||||
|
|
|
@ -617,6 +617,12 @@ def test_overfitting_IO(use_upper):
|
|||
assert ents[1].kb_id == 0
|
||||
|
||||
|
||||
def test_is_distillable():
|
||||
nlp = English()
|
||||
ner = nlp.add_pipe("ner")
|
||||
assert ner.is_distillable
|
||||
|
||||
|
||||
def test_distill():
|
||||
teacher = English()
|
||||
teacher_ner = teacher.add_pipe("ner")
|
||||
|
|
|
@ -396,6 +396,12 @@ def test_overfitting_IO(pipe_name):
|
|||
assert_equal(batch_deps_1, no_batch_deps)
|
||||
|
||||
|
||||
def test_is_distillable():
|
||||
nlp = English()
|
||||
parser = nlp.add_pipe("parser")
|
||||
assert parser.is_distillable
|
||||
|
||||
|
||||
def test_distill():
|
||||
teacher = English()
|
||||
teacher_parser = teacher.add_pipe("parser")
|
||||
|
|
|
@ -195,6 +195,12 @@ def test_overfitting_IO():
|
|||
assert doc4[3].lemma_ == "egg"
|
||||
|
||||
|
||||
def test_is_distillable():
|
||||
nlp = English()
|
||||
lemmatizer = nlp.add_pipe("trainable_lemmatizer")
|
||||
assert lemmatizer.is_distillable
|
||||
|
||||
|
||||
def test_distill():
|
||||
teacher = English()
|
||||
teacher_lemmatizer = teacher.add_pipe("trainable_lemmatizer")
|
||||
|
|
|
@ -50,6 +50,12 @@ def test_implicit_label():
|
|||
nlp.initialize(get_examples=lambda: train_examples)
|
||||
|
||||
|
||||
def test_is_distillable():
|
||||
nlp = English()
|
||||
morphologizer = nlp.add_pipe("morphologizer")
|
||||
assert morphologizer.is_distillable
|
||||
|
||||
|
||||
def test_no_resize():
|
||||
nlp = Language()
|
||||
morphologizer = nlp.add_pipe("morphologizer")
|
||||
|
|
|
@ -11,6 +11,12 @@ from spacy.pipeline import TrainablePipe
|
|||
from spacy.tests.util import make_tempdir
|
||||
|
||||
|
||||
def test_is_distillable():
|
||||
nlp = English()
|
||||
senter = nlp.add_pipe("senter")
|
||||
assert senter.is_distillable
|
||||
|
||||
|
||||
def test_label_types():
|
||||
nlp = Language()
|
||||
senter = nlp.add_pipe("senter")
|
||||
|
|
|
@ -213,6 +213,12 @@ def test_overfitting_IO():
|
|||
assert doc3[0].tag_ != "N"
|
||||
|
||||
|
||||
def test_is_distillable():
|
||||
nlp = English()
|
||||
tagger = nlp.add_pipe("tagger")
|
||||
assert tagger.is_distillable
|
||||
|
||||
|
||||
def test_distill():
|
||||
teacher = English()
|
||||
teacher_tagger = teacher.add_pipe("tagger")
|
||||
|
|
|
@ -565,6 +565,12 @@ def test_initialize_examples(name, get_examples, train_data):
|
|||
nlp.initialize(get_examples=get_examples())
|
||||
|
||||
|
||||
def test_is_distillable():
|
||||
nlp = English()
|
||||
textcat = nlp.add_pipe("textcat")
|
||||
assert not textcat.is_distillable
|
||||
|
||||
|
||||
def test_overfitting_IO():
|
||||
# Simple test to try and quickly overfit the single-label textcat component - ensuring the ML models work correctly
|
||||
fix_random_seed(0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user