diff --git a/spacy/language.py b/spacy/language.py index e0abfd5e7..849f7810f 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1337,9 +1337,18 @@ class Language: ops = get_current_ops() if self.vocab.vectors.shape[1] >= 1: self.vocab.vectors.to_ops(ops) + + # Create rehearsal models for name, proc in self.pipeline: if hasattr(proc, "_rehearsal_model"): proc._rehearsal_model = deepcopy(proc.model) # type: ignore[attr-defined] + + # Link listeners from rehearsal models to Tok2Vec components + for i, (name1, proc1) in enumerate(self.pipeline): + if isinstance(proc1, ty.ListenedToComponent): + for name2, proc2 in self.pipeline[i + 1 :]: + proc1.find_listeners(proc2) + if sgd is not None: self._optimizer = sgd elif self._optimizer is None: diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py index c742aaeaa..913205b10 100644 --- a/spacy/pipeline/tok2vec.py +++ b/spacy/pipeline/tok2vec.py @@ -59,6 +59,7 @@ class Tok2Vec(TrainablePipe): """ self.vocab = vocab self.model = model + self._rehearsal_model = None self.name = name self.listener_map: Dict[str, List["Tok2VecListener"]] = {} self.cfg: Dict[str, Any] = {} @@ -108,6 +109,11 @@ class Tok2Vec(TrainablePipe): for node in component.model.walk(): if isinstance(node, Tok2VecListener) and node.upstream_name in names: self.add_listener(node, component.name) + # Make sure to link to Tok2VecListeners from rehearsal models + if isinstance(getattr(component, "_rehearsal_model", None), Model): + for node in component._rehearsal_model.walk(): + if isinstance(node, Tok2VecListener) and node.upstream_name in names: + self.add_listener(node, component.name + "_rehearsal_model") def predict(self, docs: Iterable[Doc]): """Apply the pipeline's model to a batch of docs, without modifying them. @@ -191,6 +197,57 @@ class Tok2Vec(TrainablePipe): self.listeners[-1].receive(batch_id, tokvecs, backprop) return losses + def rehearse( + self, + examples: Iterable[Example], + *, + drop: float = 0.0, + sgd: Optional[Optimizer] = None, + losses: Optional[Dict[str, float]] = None, + ): + """Perform a "rehearsal" update from a batch of data. Rehearsal updates + teach the current model to make predictions similar to an initial model, + to try to address the "catastrophic forgetting" problem. This feature is + experimental. + + examples (Iterable[Example]): A batch of Example objects. + drop (float): The dropout rate. + sgd (thinc.api.Optimizer): The optimizer. + losses (Dict[str, float]): Optional record of the loss during training. + Updated using the component name as the key. + RETURNS (Dict[str, float]): The updated losses dictionary. + + DOCS: https://spacy.io/api/tok2vec#rehearse + """ + if losses is None: + losses = {} + if self._rehearsal_model is None: + return losses + validate_examples(examples, "Tok2Vec.rehearse") + docs = [eg.predicted for eg in examples] + set_dropout_rate(self.model, drop) + tokvecs, bp_tokvecs = self.model.begin_update(docs) + target, _ = self._rehearsal_model.begin_update(docs) + d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs] + losses.setdefault(self.name, 0.0) + + for i in range(len(target)): + d_tokvecs[i] += target[i] + losses[self.name] += float((target[i] ** 2).sum()) + + def empty_backprop(_): + return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs] + + batch_id = Tok2VecListener.get_batch_id(docs) + for listener in self.listeners: + listener.receive(batch_id, tokvecs, empty_backprop) + + bp_tokvecs(d_tokvecs) + if sgd is not None: + self.finish_update(sgd) + + return losses + def get_loss(self, examples, scores) -> None: pass diff --git a/spacy/tests/training/test_rehearse.py b/spacy/tests/training/test_rehearse.py index 5ac7fc217..11aaace20 100644 --- a/spacy/tests/training/test_rehearse.py +++ b/spacy/tests/training/test_rehearse.py @@ -1,6 +1,7 @@ import pytest import spacy +from thinc.api import Config from typing import List from spacy.training import Example @@ -148,6 +149,86 @@ REHEARSE_DATA = [ ), ] +TEXTCAT_MULTILABEL_LISTENER_CONFIG = """ +[nlp] +lang = "en" +pipeline = ["tok2vec","textcat_multilabel"] +disabled = [] +before_creation = null +after_creation = null +after_pipeline_creation = null +batch_size = 1000 +tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"} +[components] +[components.textcat_multilabel] +factory = "textcat_multilabel" +threshold = 0.5 +[components.textcat_multilabel.model] +@architectures = "spacy.TextCatEnsemble.v2" +nO = null +[components.textcat_multilabel.model.linear_model] +@architectures = "spacy.TextCatBOW.v2" +exclusive_classes = false +ngram_size = 1 +no_output_layer = false +[components.textcat_multilabel.model.tok2vec] +@architectures = "spacy.Tok2VecListener.v1" +width = 64 +upstream = "*" +[components.tok2vec] +factory = "tok2vec" +[components.tok2vec.model] +@architectures = "spacy.Tok2Vec.v2" +[components.tok2vec.model.embed] +@architectures = "spacy.MultiHashEmbed.v2" +width = 64 +attrs = ["ORTH", "SHAPE"] +rows = [5000, 2500] +include_static_vectors = true +[components.tok2vec.model.encode] +@architectures = "spacy.MishWindowEncoder.v2" +width = 64 +depth = 4 +window_size = 1 +""" + +NER_LISTENER_CONFIG = """ +[nlp] +lang = "en" +pipeline = ["tok2vec","ner"] +batch_size = 1000 +[components] +[components.tok2vec] +factory = "tok2vec" +[components.tok2vec.model] +@architectures = "spacy.Tok2Vec.v2" +[components.tok2vec.model.embed] +@architectures = "spacy.MultiHashEmbed.v2" +width = ${components.tok2vec.model.encode.width} +attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"] +rows = [5000, 1000, 2500, 2500] +include_static_vectors = false +[components.tok2vec.model.encode] +@architectures = "spacy.MaxoutWindowEncoder.v2" +width = 96 +depth = 4 +window_size = 1 +maxout_pieces = 3 +[components.ner] +factory = "ner" +[components.ner.model] +@architectures = "spacy.TransitionBasedParser.v2" +state_type = "ner" +extra_state_tokens = false +hidden_width = 64 +maxout_pieces = 2 +use_upper = true +nO = null +[components.ner.model.tok2vec] +@architectures = "spacy.Tok2VecListener.v1" +width = ${components.tok2vec.model.encode.width} +""" + def _add_ner_label(ner, data): for _, annotations in data: @@ -197,7 +278,11 @@ def _optimize(nlp, component: str, data: List, rehearse: bool): doc = nlp.make_doc(text) example = Example.from_dict(doc, annotation) if rehearse: - nlp.rehearse([example], sgd=optimizer) + nlp.update([example], sgd=None) + nlp.rehearse([example], sgd=None) + for name, proc in nlp.pipeline: + if proc.is_trainable and proc.model not in (True, False, None): + proc.finish_update(optimizer) else: nlp.update([example], sgd=optimizer) return nlp @@ -209,3 +294,21 @@ def test_rehearse(component): nlp.add_pipe(component) nlp = _optimize(nlp, component, TRAIN_DATA, False) _optimize(nlp, component, REHEARSE_DATA, True) + + +@pytest.mark.issue(12044) +def test_rehearse_textcat_multilabel_listener(): + """Test nlp.rehearse on a textcat_multilabel pipeline with a tok2vec listener""" + config = Config().from_str(TEXTCAT_MULTILABEL_LISTENER_CONFIG) + nlp = spacy.blank("en", config=config) + nlp = _optimize(nlp, "textcat_multilabel", TRAIN_DATA, False) + _optimize(nlp, "textcat_multilabel", REHEARSE_DATA, True) + + +@pytest.mark.issue(12044) +def test_rehearse_ner_listener(): + """Test nlp.rehearse on a ner pipeline with a tok2vec listener""" + config = Config().from_str(NER_LISTENER_CONFIG) + nlp = spacy.blank("en", config=config) + nlp = _optimize(nlp, "ner", TRAIN_DATA, False) + _optimize(nlp, "ner", REHEARSE_DATA, True) diff --git a/website/docs/api/tok2vec.mdx b/website/docs/api/tok2vec.mdx index a1bb1265e..c67c19c41 100644 --- a/website/docs/api/tok2vec.mdx +++ b/website/docs/api/tok2vec.mdx @@ -205,6 +205,31 @@ Delegates to [`predict`](/api/tok2vec#predict). | `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ | | **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ | +## Tok2Vec.rehearse {id="rehearse",tag="method,experimental",version="3.5.1"} + +Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the +current model to make predictions similar to an initial model, to try to address +the "catastrophic forgetting" problem. Please note that `Tok2Vec.rehearse` needs to be used together with `Tok2Vec.update`. This feature is experimental. + +> #### Example +> +> ```python +> tok2vec = nlp.add_pipe("tok2vec") +> optimizer = nlp.resume_training() +> update_losses = tok2vec.update(examples, sgd=None) +> rehearse_losses = tok2vec.rehearse(examples, sgd=None) +> tok2vec.finish_update(optimizer) +> ``` + +| Name | Description | +| -------------- | ------------------------------------------------------------------------------------------------------------- | +| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ | +| _keyword-only_ | | +| `drop` | The dropout rate. ~~float~~ | +| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ | +| `losses` | Dictionary to update with the loss, keyed by pipeline component. ~~Optional[Dict[str, float]]~~ | +| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ | + ## Tok2Vec.create_optimizer {id="create_optimizer",tag="method"} Create an optimizer for the pipeline component.