From d0eab0b99559b1ebb41b207722e17a8ac1d37484 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?=
Date: Tue, 17 Jan 2023 19:51:00 +0100
Subject: [PATCH] Add `Language.distill`

This method is the distillation counterpart of `Language.update`. It
takes a teacher `Language` instance and distills the student pipes on
the teacher pipes.
---
 spacy/language.py             | 104 +++++++++++++++++++++++++++++++++-
 spacy/tests/test_language.py  |  68 ++++++++++++++++++++++
 spacy/ty.py                   |  19 +++++++
 website/docs/api/language.mdx |  28 +++++++++
 4 files changed, 218 insertions(+), 1 deletion(-)

diff --git a/spacy/language.py b/spacy/language.py
index dcb62aef0..d3b183d46 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -22,7 +22,7 @@ from . import ty
 from .tokens.underscore import Underscore
 from .vocab import Vocab, create_vocab
 from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
-from .training import Example, validate_examples
+from .training import Example, validate_examples, validate_distillation_examples
 from .training.initialize import init_vocab, init_tok2vec
 from .scorer import Scorer
 from .util import registry, SimpleFrozenList, _pipe, raise_error, _DEFAULT_EMPTY_PIPES
@@ -1018,6 +1018,98 @@ class Language:
             raise ValueError(Errors.E005.format(name=name, returned_type=type(doc)))
         return doc
 
+    def distill(
+        self,
+        teacher: "Language",
+        examples: Iterable[Example],
+        *,
+        drop: float = 0.0,
+        sgd: Optional[Optimizer] = None,
+        losses: Optional[Dict[str, float]] = None,
+        component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
+        exclude: Iterable[str] = SimpleFrozenList(),
+        annotates: Iterable[str] = SimpleFrozenList(),
+        component_map: Optional[Dict[str, str]] = None,
+    ):
+        """Distill the models in a student pipeline from a teacher pipeline.
+        teacher (Language): Teacher to distill from.
+        examples (Iterable[Example]): A batch of examples.
+        drop (float): The dropout rate.
+        sgd (Optional[Optimizer]): An optimizer.
+        losses (Optional[Dict[str, float]]): Dictionary to update with the loss,
+            keyed by component.
+        component_cfg (Optional[Dict[str, Dict[str, Any]]]): Config parameters
+            for specific pipeline components, keyed by component name.
+        exclude (Iterable[str]): Names of components that shouldn't be updated.
+        annotates (Iterable[str]): Names of components that should set
+            annotations on the predicted examples after updating.
+        component_map (Optional[Dict[str, str]]): Map student pipe name to
+            teacher pipe name, only needed for pipes where the student pipe
+            name does not match the teacher pipe name.
+        RETURNS (Dict[str, float]): The updated losses dictionary
+
+        DOCS: https://spacy.io/api/language#distill
+        """
+        if component_map is None:
+            component_map = {}
+        if losses is None:
+            losses = {}
+        if isinstance(examples, list) and len(examples) == 0:
+            return losses
+
+        validate_distillation_examples(examples, "Language.distill")
+        examples = _copy_examples(examples)
+
+        if sgd is None:
+            if self._optimizer is None:
+                self._optimizer = self.create_optimizer()
+            sgd = self._optimizer
+
+        if component_cfg is None:
+            component_cfg = {}
+        pipe_kwargs = {}
+        for name, student_proc in self.pipeline:
+            component_cfg.setdefault(name, {})
+            pipe_kwargs[name] = deepcopy(component_cfg[name])
+            component_cfg[name].setdefault("drop", drop)
+            pipe_kwargs[name].setdefault("batch_size", self.batch_size)
+
+        teacher_pipes = dict(teacher.pipeline)
+        for name, student_proc in self.pipeline:
+            if (
+                name not in exclude
+                and isinstance(student_proc, ty.DistillableComponent)
+                and student_proc.is_distillable
+            ):
+                # A missing teacher pipe is not an error; some student pipes,
+                # such as tok2vec layer losses, do not need a teacher.
+                teacher_pipe_name = (
+                    component_map[name] if name in component_map else name
+                )
+                teacher_pipe = teacher_pipes.get(teacher_pipe_name, None)
+                student_proc.distill(
+                    teacher_pipe,
+                    examples,
+                    sgd=None,
+                    losses=losses,
+                    **component_cfg[name],
+                )
+                if sgd is not None:
+                    student_proc.finish_update(sgd)
+            if name in annotates:
+                for doc, eg in zip(
+                    _pipe(
+                        (eg.predicted for eg in examples),
+                        proc=student_proc,
+                        name=name,
+                        default_error_handler=self.default_error_handler,
+                        kwargs=pipe_kwargs[name],
+                    ),
+                    examples,
+                ):
+                    eg.predicted = doc
+        return losses
+
     def disable_pipes(self, *names) -> "DisabledPipes":
         """Disable one or more pipeline components. If used as a context
         manager, the pipeline will be restored to the initial state at the end
@@ -1243,12 +1335,16 @@ class Language:
         self,
         get_examples: Optional[Callable[[], Iterable[Example]]] = None,
         *,
+        labels: Optional[Dict[str, Any]] = None,
         sgd: Optional[Optimizer] = None,
     ) -> Optimizer:
         """Initialize the pipe for training, using data examples if available.
 
         get_examples (Callable[[], Iterable[Example]]): Optional function that
             returns gold-standard Example objects.
+        labels (Dict[str, Any]): Labels to pass to pipe initialization, using
+            the names of the pipes as keys. Overrides labels that are in the
+            model configuration.
         sgd (Optional[Optimizer]): An optimizer to use for updates. If not
             provided, will be created using the .create_optimizer() method.
         RETURNS (thinc.api.Optimizer): The optimizer.
@@ -1293,6 +1389,8 @@ class Language:
         for name, proc in self.pipeline:
             if isinstance(proc, ty.InitializableComponent):
                 p_settings = I["components"].get(name, {})
+                if labels is not None and name in labels:
+                    p_settings["labels"] = labels[name]
                 p_settings = validate_init_settings(
                     proc.initialize, p_settings, section="components", name=name
                 )
@@ -1726,6 +1824,7 @@ class Language:
         # using the nlp.config with all defaults.
         config = util.copy_config(config)
         orig_pipeline = config.pop("components", {})
+        orig_distill = config.pop("distill", None)
         orig_pretraining = config.pop("pretraining", None)
         config["components"] = {}
         if auto_fill:
@@ -1734,6 +1833,9 @@ class Language:
             filled = config
         filled["components"] = orig_pipeline
         config["components"] = orig_pipeline
+        if orig_distill is not None:
+            filled["distill"] = orig_distill
+            config["distill"] = orig_distill
         if orig_pretraining is not None:
             filled["pretraining"] = orig_pretraining
             config["pretraining"] = orig_pretraining
diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py
index 03790eb86..1f20fe455 100644
--- a/spacy/tests/test_language.py
+++ b/spacy/tests/test_language.py
@@ -26,6 +26,12 @@ except ImportError:
     pass
 
 
+TAGGER_TRAIN_DATA = [
+    ("I like green eggs", {"tags": ["N", "V", "J", "N"]}),
+    ("Eat blue ham", {"tags": ["V", "J", "N"]}),
+]
+
+
 def evil_component(doc):
     if "2" in doc.text:
         raise ValueError("no dice")
@@ -799,3 +805,65 @@ def test_component_return():
     nlp.add_pipe("test_component_bad_pipe")
     with pytest.raises(ValueError, match="instead of a Doc"):
         nlp("text")
+
+
+@pytest.mark.parametrize("teacher_tagger_name", ["tagger", "teacher_tagger"])
+def test_distill(teacher_tagger_name):
+    teacher = English()
+    teacher_tagger = teacher.add_pipe("tagger", name=teacher_tagger_name)
+    train_examples = []
+    for t in TAGGER_TRAIN_DATA:
+        train_examples.append(Example.from_dict(teacher.make_doc(t[0]), t[1]))
+
+    optimizer = teacher.initialize(get_examples=lambda: train_examples)
+
+    for i in range(50):
+        losses = {}
+        teacher.update(train_examples, sgd=optimizer, losses=losses)
+    assert losses[teacher_tagger_name] < 0.00001
+
+    student = English()
+    student_tagger = student.add_pipe("tagger")
+    student_tagger.min_tree_freq = 1
+    student_tagger.initialize(
+        get_examples=lambda: train_examples, labels=teacher_tagger.label_data
+    )
+
+    distill_examples = [
+        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TAGGER_TRAIN_DATA
+    ]
+
+    component_map = (
+        None
+        if teacher_tagger.name == student_tagger.name
+        else {student_tagger.name: teacher_tagger.name}
+    )
+
+    for i in range(50):
+        losses = {}
+        student.distill(
+            teacher,
+            distill_examples,
+            sgd=optimizer,
+            losses=losses,
+            component_map=component_map,
+        )
+    assert losses["tagger"] < 0.00001
+
+    test_text = "I like blue eggs"
+    doc = student(test_text)
+    assert doc[0].tag_ == "N"
+    assert doc[1].tag_ == "V"
+    assert doc[2].tag_ == "J"
+    assert doc[3].tag_ == "N"
+
+    # Do an extra update to check that `annotates` works, though we can't
+    # really validate the results, since the annotations are ephemeral.
+    student.distill(
+        teacher,
+        distill_examples,
+        sgd=optimizer,
+        losses=losses,
+        component_map=component_map,
+        annotates=["tagger"],
+    )
diff --git a/spacy/ty.py b/spacy/ty.py
index 8f2903d78..7f95c84b9 100644
--- a/spacy/ty.py
+++ b/spacy/ty.py
@@ -27,6 +27,25 @@ class TrainableComponent(Protocol):
         ...
 
 
+@runtime_checkable
+class DistillableComponent(Protocol):
+    is_distillable: bool
+
+    def distill(
+        self,
+        teacher_pipe: Optional[TrainableComponent],
+        examples: Iterable["Example"],
+        *,
+        drop: float = 0.0,
+        sgd: Optional[Optimizer] = None,
+        losses: Optional[Dict[str, float]] = None
+    ) -> Dict[str, float]:
+        ...
+
+    def finish_update(self, sgd: Optimizer) -> None:
+        ...
+
+
 @runtime_checkable
 class InitializableComponent(Protocol):
     def initialize(
diff --git a/website/docs/api/language.mdx b/website/docs/api/language.mdx
index a34ea7242..557928782 100644
--- a/website/docs/api/language.mdx
+++ b/website/docs/api/language.mdx
@@ -333,6 +333,34 @@ and custom registered functions if needed. See the
 | `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ |
 | **RETURNS**     | The updated `losses` dictionary. ~~Dict[str, float]~~                                                                                           |
 
+## Language.distill {id="distill",tag="method",version="4"}
+
+Distill the models in a student pipeline from a teacher pipeline.
+
+> #### Example
+>
+> ```python
+>
+> teacher = spacy.load("en_core_web_lg")
+> student = English()
+> student.add_pipe("tagger")
+> student.distill(teacher, examples, sgd=optimizer)
+> ```
+
+| Name            | Description                                                                                                                                     |
+| --------------- | ----------------------------------------------------------------------------------------------------------------------------------------------- |
+| `teacher`       | The teacher pipeline to distill from. ~~Language~~                                                                                              |
+| `examples`      | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~                                                               |
+| _keyword-only_  |                                                                                                                                                 |
+| `drop`          | The dropout rate. ~~float~~                                                                                                                     |
+| `sgd`           | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~                                   |
+| `losses`        | Dictionary to update with the loss, keyed by pipeline component. ~~Optional[Dict[str, float]]~~                                                 |
+| `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~  |
+| `exclude`       | Names of components that shouldn't be updated. Defaults to `[]`. ~~Iterable[str]~~                                                              |
+| `annotates`     | Names of components that should set annotations on the predicted examples after updating. Defaults to `[]`. ~~Iterable[str]~~                   |
+| `component_map` | Map student component names to teacher component names; only necessary when the names differ. Defaults to `None`. ~~Optional[Dict[str, str]]~~  |
+| **RETURNS**     | The updated `losses` dictionary. ~~Dict[str, float]~~                                                                                           |
+
 ## Language.rehearse {id="rehearse",tag="method,experimental",version="3"}
 
 Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
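
The patch's two additions compose into a short end-to-end loop: initialize the student from the teacher's label inventory via the new `labels` argument to `Language.initialize`, then call `Language.distill` with unannotated examples. A minimal sketch of that flow, modeled on `test_distill` above; the teacher model name and the texts are placeholders, not part of the patch:

```python
import spacy
from spacy.lang.en import English
from spacy.training import Example

# Placeholder teacher: any trained pipeline containing a tagger.
teacher = spacy.load("en_core_web_sm")

# Distillation needs no gold annotations, only raw text.
texts = ["I like green eggs", "Eat blue ham"]
examples = [Example.from_dict(teacher.make_doc(t), {}) for t in texts]

student = English()
student.add_pipe("tagger")
# New in this patch: `labels` seeds the student with the teacher's label set.
student.initialize(
    get_examples=lambda: examples,
    labels={"tagger": teacher.get_pipe("tagger").label_data},
)

optimizer = student.create_optimizer()
for _ in range(50):
    losses = {}
    student.distill(teacher, examples, sgd=optimizer, losses=losses)
```

Because the student and teacher pipe are both named `tagger` here, no `component_map` is needed; with a differently named teacher pipe you would pass e.g. `component_map={"tagger": "teacher_tagger"}`, as the parametrized test exercises.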
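On the `spacy/ty.py` side, `Language.distill` selects pipes through the new runtime-checkable protocol rather than a base class, so any object with the right shape participates. A hypothetical minimal component that passes the `ty.DistillableComponent` check; the class name, loss key, and no-op internals are illustrative only:

```python
from typing import Dict, Iterable, Optional

from thinc.api import Optimizer

from spacy import ty
from spacy.training import Example


class ToyDistillablePipe:
    """Hypothetical pipe shaped to satisfy the DistillableComponent protocol."""

    is_distillable = True

    def distill(
        self,
        teacher_pipe,
        examples: Iterable[Example],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        losses = {} if losses is None else losses
        # A real pipe would compute a distillation loss against the teacher
        # pipe's predictions here and backpropagate it through its model.
        losses.setdefault("toy_pipe", 0.0)
        return losses

    def finish_update(self, sgd: Optimizer) -> None:
        # A real pipe would apply the accumulated gradients here.
        pass


# The protocol is @runtime_checkable, so the structural isinstance check
# below is the same mechanism Language.distill uses to find eligible pipes.
assert isinstance(ToyDistillablePipe(), ty.DistillableComponent)
```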