Add rehearse method to tok2vec and link listeners

This commit is contained in:
thomashacker 2023-01-05 12:27:14 +01:00
parent 97a9c03398
commit 702de13992
2 changed files with 74 additions and 1 deletions

View File

@ -1216,7 +1216,6 @@ class Language:
self._optimizer = self.create_optimizer()
sgd = self._optimizer
pipes = list(self.pipeline)
random.shuffle(pipes)
if component_cfg is None:
component_cfg = {}
grads = {}
@ -1340,6 +1339,13 @@ class Language:
for name, proc in self.pipeline:
if hasattr(proc, "_rehearsal_model"):
proc._rehearsal_model = deepcopy(proc.model)  # type: ignore[attr-defined]
# Relink the listeners of rehearsal models to their respective upstream tok2vec component
# Otherwise they won't be synced with the tok2vec and throw a mismatched ID error
for i, (name1, proc1) in enumerate(self.pipeline):
if isinstance(proc1, ty.ListenedToComponent):
for name2, proc2 in self.pipeline[i + 1 :]:
proc1.find_listeners(proc2)
if sgd is not None:
self._optimizer = sgd
elif self._optimizer is None:

View File

@ -59,6 +59,7 @@ class Tok2Vec(TrainablePipe):
""" """
self.vocab = vocab self.vocab = vocab
self.model = model self.model = model
self._rehearsal_model = None
self.name = name self.name = name
self.listener_map: Dict[str, List["Tok2VecListener"]] = {} self.listener_map: Dict[str, List["Tok2VecListener"]] = {}
self.cfg: Dict[str, Any] = {} self.cfg: Dict[str, Any] = {}
@ -108,6 +109,11 @@ class Tok2Vec(TrainablePipe):
for node in component.model.walk():
if isinstance(node, Tok2VecListener) and node.upstream_name in names:
self.add_listener(node, component.name)
# Make sure to link to Tok2VecListeners from rehearsal models
if isinstance(getattr(component, "_rehearsal_model", None), Model):
for node in component._rehearsal_model.walk():
if isinstance(node, Tok2VecListener) and node.upstream_name in names:
self.add_listener(node, component.name + "_rehearsal_model")
def predict(self, docs: Iterable[Doc]):
"""Apply the pipeline's model to a batch of docs, without modifying them.
@ -191,6 +197,67 @@ class Tok2Vec(TrainablePipe):
self.listeners[-1].receive(batch_id, tokvecs, backprop)
return losses
def rehearse(
    self,
    examples: Iterable[Example],
    *,
    drop: float = 0.0,
    sgd: Optional[Optimizer] = None,
    losses: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """Perform a "rehearsal" update from a batch of data. Rehearsal updates
    teach the current model to make predictions similar to an initial model,
    to try to address the "catastrophic forgetting" problem. This feature is
    experimental.

    examples (Iterable[Example]): A batch of Example objects.
    drop (float): The dropout rate.
    sgd (thinc.api.Optimizer): The optimizer.
    losses (Dict[str, float]): Optional record of the loss during training.
        Updated using the component name as the key.
    RETURNS (Dict[str, float]): The updated losses dictionary.

    TODO DOCS: https://spacy.io/api/tok2vec#rehearse
    """
    if losses is None:
        losses = {}
    # Without a saved copy of the pre-resume model there is nothing to
    # rehearse against, so this update is a no-op.
    if self._rehearsal_model is None:
        return losses
    validate_examples(examples, "Tok2Vec.rehearsal")
    docs = [eg.predicted for eg in examples]
    set_dropout_rate(self.model, drop)
    tokvecs, bp_tokvecs = self.model.begin_update(docs)
    # Forward pass through the rehearsal model gives the "target" output the
    # current model should stay close to; its backprop callback is discarded.
    target, _ = self._rehearsal_model.begin_update(docs)
    # TODO tokvecs vs target
    # How should the output from the rehearsal model influence the results of the tok2vec model?
    # Zero-initialized gradient buffers, one per doc, matching each output's shape.
    d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
    losses.setdefault(self.name, 0.0)

    def accumulate_gradient(one_d_tokvecs):
        """Accumulate tok2vec loss and gradient. This is passed as a callback
        to all but the last listener. Only the last one does the backprop.
        """
        nonlocal d_tokvecs
        for i in range(len(one_d_tokvecs)):
            d_tokvecs[i] += one_d_tokvecs[i]
            # Track the squared-gradient magnitude as this component's loss.
            losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
        # Hand back zero gradients: the real backward pass happens once, in
        # `backprop` below, after every listener has contributed.
        return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]

    def backprop(one_d_tokvecs):
        """Callback to actually do the backprop. Passed to last listener."""
        accumulate_gradient(one_d_tokvecs)
        d_docs = bp_tokvecs(d_tokvecs)
        if sgd is not None:
            self.finish_update(sgd)
        return d_docs

    batch_id = Tok2VecListener.get_batch_id(docs)
    # All listeners except the last only accumulate gradient; the final one
    # triggers the single backward pass through this component's model.
    for listener in self.listeners[:-1]:
        listener.receive(batch_id, tokvecs, accumulate_gradient)
    if self.listeners:
        self.listeners[-1].receive(batch_id, tokvecs, backprop)
    return losses
def get_loss(self, examples, scores) -> None:
pass