Mirror of https://github.com/explosion/spaCy.git
Add Pipe.is_distillable method
commit 1fb095a79b
parent dd83157594
@@ -87,6 +87,10 @@ cdef class Pipe:
             return self.scorer(examples, **scorer_kwargs)
         return {}
 
+    @property
+    def is_distillable(self) -> bool:
+        return False
+
     @property
     def is_trainable(self) -> bool:
         return False
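For the non-trainable base class, the property is a conservative default: any component that does not set up distillation reports False. A minimal sketch of what that means in practice, assuming a spaCy build that includes this commit and that the rule-based sentencizer still derives from Pipe:

# Sketch only (not part of the diff): the Pipe-level default added above.
import spacy

nlp = spacy.blank("en")
sentencizer = nlp.add_pipe("sentencizer")  # rule-based component, assumed to subclass Pipe
assert not sentencizer.is_distillable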
@@ -264,6 +264,10 @@ cdef class TrainablePipe(Pipe):
         """
         raise NotImplementedError(Errors.E931.format(parent="Pipe", method="add_label", name=self.name))
 
+    @property
+    def is_distillable(self) -> bool:
+        return not (self.__class__.distill is TrainablePipe.distill and self.__class__.get_teacher_student_loss is TrainablePipe.get_teacher_student_loss)
+
     @property
     def is_trainable(self) -> bool:
         return True
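TrainablePipe infers distillability from method overrides rather than from a flag: a subclass counts as distillable unless it leaves both distill and get_teacher_student_loss at the TrainablePipe defaults. A plain-Python sketch of that idiom follows; the class names are illustrative, not spaCy's (the real base is a Cython cdef class):

# Plain-Python sketch of the override-detection idiom used above.
class Base:
    def distill(self, teacher_pipe, examples):
        raise NotImplementedError

    def get_teacher_student_loss(self, teacher_scores, student_scores):
        raise NotImplementedError

    @property
    def is_distillable(self) -> bool:
        # Distillable unless both hooks are still the Base implementations.
        return not (
            self.__class__.distill is Base.distill
            and self.__class__.get_teacher_student_loss is Base.get_teacher_student_loss
        )


class UsesDefaults(Base):
    pass


class OverridesHooks(Base):
    def distill(self, teacher_pipe, examples):
        return {}

    def get_teacher_student_loss(self, teacher_scores, student_scores):
        return 0.0, None


assert not UsesDefaults().is_distillable
assert OverridesHooks().is_distillable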
@@ -617,6 +617,12 @@ def test_overfitting_IO(use_upper):
     assert ents[1].kb_id == 0
 
 
+def test_is_distillable():
+    nlp = English()
+    ner = nlp.add_pipe("ner")
+    assert ner.is_distillable
+
+
 def test_distill():
     teacher = English()
     teacher_ner = teacher.add_pipe("ner")
@@ -396,6 +396,12 @@ def test_overfitting_IO(pipe_name):
     assert_equal(batch_deps_1, no_batch_deps)
 
 
+def test_is_distillable():
+    nlp = English()
+    parser = nlp.add_pipe("parser")
+    assert parser.is_distillable
+
+
 def test_distill():
     teacher = English()
     teacher_parser = teacher.add_pipe("parser")
@@ -195,6 +195,12 @@ def test_overfitting_IO():
     assert doc4[3].lemma_ == "egg"
 
 
+def test_is_distillable():
+    nlp = English()
+    lemmatizer = nlp.add_pipe("trainable_lemmatizer")
+    assert lemmatizer.is_distillable
+
+
 def test_distill():
     teacher = English()
     teacher_lemmatizer = teacher.add_pipe("trainable_lemmatizer")
@@ -50,6 +50,12 @@ def test_implicit_label():
     nlp.initialize(get_examples=lambda: train_examples)
 
 
+def test_is_distillable():
+    nlp = English()
+    morphologizer = nlp.add_pipe("morphologizer")
+    assert morphologizer.is_distillable
+
+
 def test_no_resize():
     nlp = Language()
     morphologizer = nlp.add_pipe("morphologizer")
@@ -11,6 +11,12 @@ from spacy.pipeline import TrainablePipe
 from spacy.tests.util import make_tempdir
 
 
+def test_is_distillable():
+    nlp = English()
+    senter = nlp.add_pipe("senter")
+    assert senter.is_distillable
+
+
 def test_label_types():
     nlp = Language()
     senter = nlp.add_pipe("senter")
@@ -213,6 +213,12 @@ def test_overfitting_IO():
     assert doc3[0].tag_ != "N"
 
 
+def test_is_distillable():
+    nlp = English()
+    tagger = nlp.add_pipe("tagger")
+    assert tagger.is_distillable
+
+
 def test_distill():
     teacher = English()
     teacher_tagger = teacher.add_pipe("tagger")
@@ -565,6 +565,12 @@ def test_initialize_examples(name, get_examples, train_data):
     nlp.initialize(get_examples=get_examples())
 
 
+def test_is_distillable():
+    nlp = English()
+    textcat = nlp.add_pipe("textcat")
+    assert not textcat.is_distillable
+
+
 def test_overfitting_IO():
     # Simple test to try and quickly overfit the single-label textcat component - ensuring the ML models work correctly
     fix_random_seed(0)
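Taken together, the tests assert that the trainable components (ner, parser, trainable_lemmatizer, morphologizer, senter, tagger) report is_distillable while textcat does not. A hedged sketch of how a caller might use the property to pick distillation candidates; the pipeline composition is chosen only for illustration and assumes a build containing this commit:

# Sketch only: select distillation candidates with the new property.
import spacy

nlp = spacy.blank("en")
nlp.add_pipe("tagger")   # overrides the distillation hooks per the tests above
nlp.add_pipe("textcat")  # does not

distillable = [name for name, pipe in nlp.pipeline if getattr(pipe, "is_distillable", False)]
print(distillable)  # expected: ["tagger"]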