Add rehearse method to tok2vec and link listeners

This commit is contained in:
thomashacker 2023-01-05 12:27:14 +01:00
parent 97a9c03398
commit 702de13992
2 changed files with 74 additions and 1 deletion

View File

@@ -1216,7 +1216,6 @@ class Language:
self._optimizer = self.create_optimizer()
sgd = self._optimizer
pipes = list(self.pipeline)
random.shuffle(pipes)
if component_cfg is None:
component_cfg = {}
grads = {}
@ -1340,6 +1339,13 @@ class Language:
for name, proc in self.pipeline:
if hasattr(proc, "_rehearsal_model"):
proc._rehearsal_model = deepcopy(proc.model) # type: ignore[attr-defined]
# Relink the listeners of rehearsal models to their respective upstream tok2vec component
# Otherwise they won't be synced with the tok2vec and throw a mismatched ID error
for i, (name1, proc1) in enumerate(self.pipeline):
if isinstance(proc1, ty.ListenedToComponent):
for name2, proc2 in self.pipeline[i + 1 :]:
proc1.find_listeners(proc2)
if sgd is not None:
self._optimizer = sgd
elif self._optimizer is None:

View File

@@ -59,6 +59,7 @@ class Tok2Vec(TrainablePipe):
"""
self.vocab = vocab
self.model = model
self._rehearsal_model = None
self.name = name
self.listener_map: Dict[str, List["Tok2VecListener"]] = {}
self.cfg: Dict[str, Any] = {}
@ -108,6 +109,11 @@ class Tok2Vec(TrainablePipe):
for node in component.model.walk():
if isinstance(node, Tok2VecListener) and node.upstream_name in names:
self.add_listener(node, component.name)
# Make sure to link to Tok2VecListeners from rehearsal models
if isinstance(getattr(component, "_rehearsal_model", None), Model):
for node in component._rehearsal_model.walk():
if isinstance(node, Tok2VecListener) and node.upstream_name in names:
self.add_listener(node, component.name + "_rehearsal_model")
def predict(self, docs: Iterable[Doc]):
"""Apply the pipeline's model to a batch of docs, without modifying them.
@ -191,6 +197,67 @@ class Tok2Vec(TrainablePipe):
self.listeners[-1].receive(batch_id, tokvecs, backprop)
return losses
def rehearse(
    self,
    examples: Iterable[Example],
    *,
    drop: float = 0.0,
    sgd: Optional[Optimizer] = None,
    losses: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """Perform a "rehearsal" update from a batch of data. Rehearsal updates
    teach the current model to make predictions similar to an initial model,
    to try to address the "catastrophic forgetting" problem. This feature is
    experimental.

    examples (Iterable[Example]): A batch of Example objects.
    drop (float): The dropout rate.
    sgd (thinc.api.Optimizer): The optimizer. If given, weights are updated
        when the last listener fires its backprop callback.
    losses (Dict[str, float]): Optional record of the loss during training.
        Updated using the component name as the key.
    RETURNS (Dict[str, float]): The updated losses dictionary.

    TODO DOCS: https://spacy.io/api/tok2vec#rehearse
    """
    if losses is None:
        losses = {}
    # No rehearsal model has been stored on this component yet, so there is
    # nothing to rehearse against — return the losses dict unchanged.
    if self._rehearsal_model is None:
        return losses
    # NOTE(review): the name passed here is "Tok2Vec.rehearsal" while the
    # method is `rehearse` — confirm this matches the convention used by
    # other components' validate_examples calls.
    validate_examples(examples, "Tok2Vec.rehearsal")
    docs = [eg.predicted for eg in examples]
    set_dropout_rate(self.model, drop)
    # Forward pass through the current model, keeping the backprop callback.
    tokvecs, bp_tokvecs = self.model.begin_update(docs)
    # Forward pass through the frozen rehearsal copy. Its output is bound to
    # `target` but is not used below yet — see the TODO that follows.
    target, _ = self._rehearsal_model.begin_update(docs)
    # TODO tokvecs vs target
    # How should the output from the rehearsal model influence the results of the tok2vec model?
    # Zero-initialised gradient buffers, one 2d array per doc, matching the
    # shapes of the tok2vec outputs. Listeners add their gradients into these.
    d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
    losses.setdefault(self.name, 0.0)
    def accumulate_gradient(one_d_tokvecs):
        """Accumulate tok2vec loss and gradient. This is passed as a callback
        to all but the last listener. Only the last one does the backprop.
        """
        nonlocal d_tokvecs
        for i in range(len(one_d_tokvecs)):
            d_tokvecs[i] += one_d_tokvecs[i]
            # The sum of squared incoming gradients is recorded as this
            # component's loss contribution.
            losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
        # Return fresh zero gradients to the caller (the downstream listener).
        return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
    def backprop(one_d_tokvecs):
        """Callback to actually do the backprop. Passed to last listener."""
        accumulate_gradient(one_d_tokvecs)
        # Backprop the accumulated gradients through the tok2vec model.
        d_docs = bp_tokvecs(d_tokvecs)
        if sgd is not None:
            self.finish_update(sgd)
        return d_docs
    batch_id = Tok2VecListener.get_batch_id(docs)
    # All listeners except the last only accumulate gradients; the last one
    # triggers the actual weight update via backprop().
    for listener in self.listeners[:-1]:
        listener.receive(batch_id, tokvecs, accumulate_gradient)
    if self.listeners:
        self.listeners[-1].receive(batch_id, tokvecs, backprop)
    return losses
def get_loss(self, examples, scores) -> None:
    """No-op. Tok2Vec accumulates its loss through listener callbacks during
    update/rehearse rather than computing one from predicted scores.
    """
    return None