diff --git a/spacy/language.py b/spacy/language.py
index e0abfd5e7..a277041a4 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -1216,7 +1216,6 @@ class Language:
             self._optimizer = self.create_optimizer()
         sgd = self._optimizer
         pipes = list(self.pipeline)
-        random.shuffle(pipes)
         if component_cfg is None:
             component_cfg = {}
         grads = {}
@@ -1340,6 +1339,13 @@ class Language:
         for name, proc in self.pipeline:
            if hasattr(proc, "_rehearsal_model"):
                proc._rehearsal_model = deepcopy(proc.model)  # type: ignore[attr-defined]
+
+        # Relink the listeners of rehearsal models to their respective upstream tok2vec components.
+        # Otherwise they won't be synced with the tok2vec and will throw a mismatched-ID error.
+        for i, (name1, proc1) in enumerate(self.pipeline):
+            if isinstance(proc1, ty.ListenedToComponent):
+                for name2, proc2 in self.pipeline[i + 1 :]:
+                    proc1.find_listeners(proc2)
         if sgd is not None:
             self._optimizer = sgd
         elif self._optimizer is None:
diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py
index c742aaeaa..ae2884a8f 100644
--- a/spacy/pipeline/tok2vec.py
+++ b/spacy/pipeline/tok2vec.py
@@ -59,6 +59,7 @@ class Tok2Vec(TrainablePipe):
         """
         self.vocab = vocab
         self.model = model
+        self._rehearsal_model = None
         self.name = name
         self.listener_map: Dict[str, List["Tok2VecListener"]] = {}
         self.cfg: Dict[str, Any] = {}
@@ -108,6 +109,11 @@ class Tok2Vec(TrainablePipe):
         for node in component.model.walk():
             if isinstance(node, Tok2VecListener) and node.upstream_name in names:
                 self.add_listener(node, component.name)
+        # Make sure to also link Tok2VecListeners from rehearsal models
+        if isinstance(getattr(component, "_rehearsal_model", None), Model):
+            for node in component._rehearsal_model.walk():
+                if isinstance(node, Tok2VecListener) and node.upstream_name in names:
+                    self.add_listener(node, component.name + "_rehearsal_model")
 
     def predict(self, docs: Iterable[Doc]):
         """Apply the pipeline's model to a batch of docs, without modifying them.
@@ -191,6 +197,67 @@ class Tok2Vec(TrainablePipe):
             self.listeners[-1].receive(batch_id, tokvecs, backprop)
         return losses
 
+    def rehearse(
+        self,
+        examples: Iterable[Example],
+        *,
+        drop: float = 0.0,
+        sgd: Optional[Optimizer] = None,
+        losses: Optional[Dict[str, float]] = None,
+    ) -> Dict[str, float]:
+        """Perform a "rehearsal" update from a batch of data. Rehearsal updates
+        teach the current model to make predictions similar to an initial model,
+        to try to address the "catastrophic forgetting" problem. This feature is
+        experimental.
+
+        examples (Iterable[Example]): A batch of Example objects.
+        drop (float): The dropout rate.
+        sgd (thinc.api.Optimizer): The optimizer.
+        losses (Dict[str, float]): Optional record of the loss during training.
+            Updated using the component name as the key.
+        RETURNS (Dict[str, float]): The updated losses dictionary.
+
+        TODO DOCS: https://spacy.io/api/tok2vec#rehearse
+        """
+        if losses is None:
+            losses = {}
+        if self._rehearsal_model is None:
+            return losses
+        validate_examples(examples, "Tok2Vec.rehearse")
+        docs = [eg.predicted for eg in examples]
+        set_dropout_rate(self.model, drop)
+        tokvecs, bp_tokvecs = self.model.begin_update(docs)
+        target, _ = self._rehearsal_model.begin_update(docs)
+        # TODO tokvecs vs target
+        # How should the output from the rehearsal model influence the results of the tok2vec model?
+        d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
+        losses.setdefault(self.name, 0.0)
+
+        def accumulate_gradient(one_d_tokvecs):
+            """Accumulate tok2vec loss and gradient. This is passed as a callback
+            to all but the last listener. Only the last one does the backprop.
+            """
+            nonlocal d_tokvecs
+            for i in range(len(one_d_tokvecs)):
+                d_tokvecs[i] += one_d_tokvecs[i]
+                losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
+            return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
+
+        def backprop(one_d_tokvecs):
+            """Callback to actually do the backprop. Passed to the last listener."""
+            accumulate_gradient(one_d_tokvecs)
+            d_docs = bp_tokvecs(d_tokvecs)
+            if sgd is not None:
+                self.finish_update(sgd)
+            return d_docs
+
+        batch_id = Tok2VecListener.get_batch_id(docs)
+        for listener in self.listeners[:-1]:
+            listener.receive(batch_id, tokvecs, accumulate_gradient)
+        if self.listeners:
+            self.listeners[-1].receive(batch_id, tokvecs, backprop)
+        return losses
+
     def get_loss(self, examples, scores) -> None:
         pass
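A minimal usage sketch of the rehearsal flow this diff completes. It assumes a pipeline whose downstream components listen to a shared tok2vec (as in the stock `en_core_web_sm` package) and uses made-up raw texts; `resume_training()` is the call that deep-copies each model into `_rehearsal_model` and, with the `language.py` change above, relinks the copied listeners:

```python
import random

import spacy
from spacy.training import Example

# Assumption: a pipeline with a shared tok2vec that tagger/parser/ner
# listen to, e.g. the stock en_core_web_sm package. Texts are made up.
nlp = spacy.load("en_core_web_sm")
raw_texts = ["Some raw, unlabelled text.", "More text to rehearse on."]

# resume_training() deepcopies each component's model into _rehearsal_model;
# the relinking loop added in language.py registers the Tok2VecListener
# nodes inside those copies with the upstream tok2vec.
optimizer = nlp.resume_training()

for _ in range(10):
    random.shuffle(raw_texts)
    examples = [Example.from_dict(nlp.make_doc(t), {}) for t in raw_texts]
    losses = {}
    # Language.rehearse dispatches to each component's rehearse(),
    # now including the Tok2Vec.rehearse added above.
    nlp.rehearse(examples, sgd=optimizer, losses=losses)
```

Without the relinking loop, the listener copies inside each `_rehearsal_model` never receive the broadcast batch from the shared tok2vec, which is what surfaces as the mismatched-ID error mentioned in the `language.py` comment.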
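On the open "tokvecs vs target" TODO: other spaCy `rehearse()` implementations nudge the current model's output toward the frozen rehearsal model's output. Purely as an illustration of that option (variable names taken from the diff; not part of this change), the accumulated gradient could be seeded with the distance to the tutor instead of zeros:

```python
# Hypothetical resolution of the TODO, mirroring the current-minus-tutor
# gradient used elsewhere: the further the current tok2vec output drifts
# from the rehearsal model's output, the larger the corrective gradient.
d_tokvecs = [cur - tgt for cur, tgt in zip(tokvecs, target)]
```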