From 1fb095a79b0f9eabd31567b29c50c3e4c5d7b6fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 13 Jan 2023 09:48:11 +0100 Subject: [PATCH] Add Pipe.is_distillable method --- spacy/pipeline/pipe.pyx | 4 ++++ spacy/pipeline/trainable_pipe.pyx | 4 ++++ spacy/tests/parser/test_ner.py | 6 ++++++ spacy/tests/parser/test_parse.py | 6 ++++++ spacy/tests/pipeline/test_edit_tree_lemmatizer.py | 6 ++++++ spacy/tests/pipeline/test_morphologizer.py | 6 ++++++ spacy/tests/pipeline/test_senter.py | 6 ++++++ spacy/tests/pipeline/test_tagger.py | 6 ++++++ spacy/tests/pipeline/test_textcat.py | 6 ++++++ 9 files changed, 50 insertions(+) diff --git a/spacy/pipeline/pipe.pyx b/spacy/pipeline/pipe.pyx index c5650382b..8b8fdc361 100644 --- a/spacy/pipeline/pipe.pyx +++ b/spacy/pipeline/pipe.pyx @@ -87,6 +87,10 @@ cdef class Pipe: return self.scorer(examples, **scorer_kwargs) return {} + @property + def is_distillable(self) -> bool: + return False + @property def is_trainable(self) -> bool: return False diff --git a/spacy/pipeline/trainable_pipe.pyx b/spacy/pipeline/trainable_pipe.pyx index 5a28204cf..fbcbe7a17 100644 --- a/spacy/pipeline/trainable_pipe.pyx +++ b/spacy/pipeline/trainable_pipe.pyx @@ -264,6 +264,10 @@ cdef class TrainablePipe(Pipe): """ raise NotImplementedError(Errors.E931.format(parent="Pipe", method="add_label", name=self.name)) + @property + def is_distillable(self) -> bool: + return not (self.__class__.distill is TrainablePipe.distill and self.__class__.get_teacher_student_loss is TrainablePipe.get_teacher_student_loss) + @property def is_trainable(self) -> bool: return True diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py index 9429fc746..7ef3ed869 100644 --- a/spacy/tests/parser/test_ner.py +++ b/spacy/tests/parser/test_ner.py @@ -617,6 +617,12 @@ def test_overfitting_IO(use_upper): assert ents[1].kb_id == 0 +def test_is_distillable(): + nlp = English() + ner = nlp.add_pipe("ner") + assert ner.is_distillable + + def test_distill(): teacher = English() teacher_ner = teacher.add_pipe("ner") diff --git a/spacy/tests/parser/test_parse.py b/spacy/tests/parser/test_parse.py index 089c4d066..97d112a50 100644 --- a/spacy/tests/parser/test_parse.py +++ b/spacy/tests/parser/test_parse.py @@ -396,6 +396,12 @@ def test_overfitting_IO(pipe_name): assert_equal(batch_deps_1, no_batch_deps) +def test_is_distillable(): + nlp = English() + parser = nlp.add_pipe("parser") + assert parser.is_distillable + + def test_distill(): teacher = English() teacher_parser = teacher.add_pipe("parser") diff --git a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py index 96c83a335..b855c7a26 100644 --- a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py +++ b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py @@ -195,6 +195,12 @@ def test_overfitting_IO(): assert doc4[3].lemma_ == "egg" +def test_is_distillable(): + nlp = English() + lemmatizer = nlp.add_pipe("trainable_lemmatizer") + assert lemmatizer.is_distillable + + def test_distill(): teacher = English() teacher_lemmatizer = teacher.add_pipe("trainable_lemmatizer") diff --git a/spacy/tests/pipeline/test_morphologizer.py b/spacy/tests/pipeline/test_morphologizer.py index 70fc77304..5b9b17c01 100644 --- a/spacy/tests/pipeline/test_morphologizer.py +++ b/spacy/tests/pipeline/test_morphologizer.py @@ -50,6 +50,12 @@ def test_implicit_label(): nlp.initialize(get_examples=lambda: train_examples) +def test_is_distillable(): + nlp = English() + morphologizer = nlp.add_pipe("morphologizer") + assert morphologizer.is_distillable + + def test_no_resize(): nlp = Language() morphologizer = nlp.add_pipe("morphologizer") diff --git a/spacy/tests/pipeline/test_senter.py b/spacy/tests/pipeline/test_senter.py index 3deac9e9a..a771d62fa 100644 --- a/spacy/tests/pipeline/test_senter.py +++ b/spacy/tests/pipeline/test_senter.py @@ -11,6 +11,12 @@ from spacy.pipeline import TrainablePipe from spacy.tests.util import make_tempdir +def test_is_distillable(): + nlp = English() + senter = nlp.add_pipe("senter") + assert senter.is_distillable + + def test_label_types(): nlp = Language() senter = nlp.add_pipe("senter") diff --git a/spacy/tests/pipeline/test_tagger.py b/spacy/tests/pipeline/test_tagger.py index b2fd74142..344859f8d 100644 --- a/spacy/tests/pipeline/test_tagger.py +++ b/spacy/tests/pipeline/test_tagger.py @@ -213,6 +213,12 @@ def test_overfitting_IO(): assert doc3[0].tag_ != "N" +def test_is_distillable(): + nlp = English() + tagger = nlp.add_pipe("tagger") + assert tagger.is_distillable + + def test_distill(): teacher = English() teacher_tagger = teacher.add_pipe("tagger") diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py index 304209933..9c0eeb171 100644 --- a/spacy/tests/pipeline/test_textcat.py +++ b/spacy/tests/pipeline/test_textcat.py @@ -565,6 +565,12 @@ def test_initialize_examples(name, get_examples, train_data): nlp.initialize(get_examples=get_examples()) +def test_is_distillable(): + nlp = English() + textcat = nlp.add_pipe("textcat") + assert not textcat.is_distillable + + def test_overfitting_IO(): # Simple test to try and quickly overfit the single-label textcat component - ensuring the ML models work correctly fix_random_seed(0)