Add rehearse method to tok2vec and link listeners
parent 97a9c03398
commit 702de13992
@@ -1216,7 +1216,6 @@ class Language:
                 self._optimizer = self.create_optimizer()
             sgd = self._optimizer
         pipes = list(self.pipeline)
         random.shuffle(pipes)
         if component_cfg is None:
             component_cfg = {}
         grads = {}
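
A minimal usage sketch of the API this hunk sits in (the "en_core_web_sm" pipeline and the raw_texts list are assumptions for illustration): resume_training() returns the optimizer that rehearse() falls back to when sgd is None.

    import spacy
    from spacy.training import Example
    from spacy.util import minibatch

    nlp = spacy.load("en_core_web_sm")
    optimizer = nlp.resume_training()  # sets up the rehearsal copies (see next hunk)

    # Hypothetical unlabelled texts used purely for rehearsal; the Examples
    # carry no annotations, only the predicted docs.
    raw_texts = ["Apple is looking at buying a U.K. startup.", "I like pizza."]
    raw_examples = [Example.from_dict(nlp.make_doc(t), {}) for t in raw_texts]

    losses = {}
    for batch in minibatch(raw_examples, size=2):
        nlp.rehearse(batch, sgd=optimizer, losses=losses)
    print(losses)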
@@ -1340,6 +1339,13 @@ class Language:
         for name, proc in self.pipeline:
             if hasattr(proc, "_rehearsal_model"):
                 proc._rehearsal_model = deepcopy(proc.model)  # type: ignore[attr-defined]
+
+        # Relink the listeners of rehearsal models to their respective upstream tok2vec component.
+        # Otherwise they won't be synced with the tok2vec and will throw a mismatched ID error.
+        for i, (name1, proc1) in enumerate(self.pipeline):
+            if isinstance(proc1, ty.ListenedToComponent):
+                for name2, proc2 in self.pipeline[i + 1 :]:
+                    proc1.find_listeners(proc2)
         if sgd is not None:
             self._optimizer = sgd
         elif self._optimizer is None:
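
A short sketch of the relinking effect, assuming a stock pipeline such as "en_core_web_sm" where tagger/parser/ner listen to a shared "tok2vec": after resume_training(), the tok2vec's listener map should also contain the "<name>_rehearsal_model" entries registered by find_listeners() (see the tok2vec.py hunks below).

    import spacy

    nlp = spacy.load("en_core_web_sm")
    optimizer = nlp.resume_training()  # copies models and relinks rehearsal listeners

    tok2vec = nlp.get_pipe("tok2vec")
    # Expected to show e.g. "ner" alongside "ner_rehearsal_model".
    print(sorted(tok2vec.listener_map.keys()))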
@@ -59,6 +59,7 @@ class Tok2Vec(TrainablePipe):
         """
         self.vocab = vocab
         self.model = model
+        self._rehearsal_model = None
         self.name = name
         self.listener_map: Dict[str, List["Tok2VecListener"]] = {}
         self.cfg: Dict[str, Any] = {}
@@ -108,6 +109,11 @@ class Tok2Vec(TrainablePipe):
             for node in component.model.walk():
                 if isinstance(node, Tok2VecListener) and node.upstream_name in names:
                     self.add_listener(node, component.name)
+        # Make sure to link to Tok2VecListeners from rehearsal models
+        if isinstance(getattr(component, "_rehearsal_model", None), Model):
+            for node in component._rehearsal_model.walk():
+                if isinstance(node, Tok2VecListener) and node.upstream_name in names:
+                    self.add_listener(node, component.name + "_rehearsal_model")

     def predict(self, docs: Iterable[Doc]):
         """Apply the pipeline's model to a batch of docs, without modifying them.
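
A standalone sketch of the discovery logic used above (the helper name find_listener_nodes is ours, not spaCy's): walk a thinc Model graph and collect the Tok2VecListener nodes that point at a given upstream tok2vec.

    from typing import List, Tuple

    from thinc.api import Model
    from spacy.pipeline.tok2vec import Tok2VecListener

    def find_listener_nodes(model: Model, names: Tuple[str, ...]) -> List[Tok2VecListener]:
        """Collect listener nodes whose upstream_name matches one of `names`."""
        return [
            node
            for node in model.walk()
            if isinstance(node, Tok2VecListener) and node.upstream_name in names
        ]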
@@ -191,6 +197,67 @@ class Tok2Vec(TrainablePipe):
             self.listeners[-1].receive(batch_id, tokvecs, backprop)
         return losses

+    def rehearse(
+        self,
+        examples: Iterable[Example],
+        *,
+        drop: float = 0.0,
+        sgd: Optional[Optimizer] = None,
+        losses: Optional[Dict[str, float]] = None,
+    ) -> Dict[str, float]:
+        """Perform a "rehearsal" update from a batch of data. Rehearsal updates
+        teach the current model to make predictions similar to an initial model,
+        to try to address the "catastrophic forgetting" problem. This feature is
+        experimental.
+
+        examples (Iterable[Example]): A batch of Example objects.
+        drop (float): The dropout rate.
+        sgd (thinc.api.Optimizer): The optimizer.
+        losses (Dict[str, float]): Optional record of the loss during training.
+            Updated using the component name as the key.
+        RETURNS (Dict[str, float]): The updated losses dictionary.
+
+        TODO DOCS: https://spacy.io/api/tok2vec#rehearse
+        """
+        if losses is None:
+            losses = {}
+        if self._rehearsal_model is None:
+            return losses
+        validate_examples(examples, "Tok2Vec.rehearse")
+        docs = [eg.predicted for eg in examples]
+        set_dropout_rate(self.model, drop)
+        tokvecs, bp_tokvecs = self.model.begin_update(docs)
+        target, _ = self._rehearsal_model.begin_update(docs)
+        # TODO: tokvecs vs. target
+        # How should the output from the rehearsal model influence the results of the tok2vec model?
+        d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
+        losses.setdefault(self.name, 0.0)
+
+        def accumulate_gradient(one_d_tokvecs):
+            """Accumulate tok2vec loss and gradient. This is passed as a callback
+            to all but the last listener. Only the last one does the backprop.
+            """
+            nonlocal d_tokvecs
+            for i in range(len(one_d_tokvecs)):
+                d_tokvecs[i] += one_d_tokvecs[i]
+                losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
+            return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
+
+        def backprop(one_d_tokvecs):
+            """Callback to actually do the backprop. Passed to last listener."""
+            accumulate_gradient(one_d_tokvecs)
+            d_docs = bp_tokvecs(d_tokvecs)
+            if sgd is not None:
+                self.finish_update(sgd)
+            return d_docs
+
+        batch_id = Tok2VecListener.get_batch_id(docs)
+        for listener in self.listeners[:-1]:
+            listener.receive(batch_id, tokvecs, accumulate_gradient)
+        if self.listeners:
+            self.listeners[-1].receive(batch_id, tokvecs, backprop)
+        return losses
+
     def get_loss(self, examples, scores) -> None:
         pass

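
One possible way to resolve the TODO above (a sketch, not the committed behavior): treat the rehearsal model's output as a regression target, so the gradient pulls the current tok2vec back toward its initial predictions with a squared-error objective. tokvecs and target mirror the variables in the diff; the helper name rehearsal_gradient is ours.

    def rehearsal_gradient(tokvecs, target):
        """Per-doc gradient and loss for L = 0.5 * sum((tokvecs - target) ** 2)."""
        d_tokvecs = []
        loss = 0.0
        for t2v, tgt in zip(tokvecs, target):
            diff = t2v - tgt          # dL/d(tokvecs) for the L2 objective
            d_tokvecs.append(diff)
            loss += float((diff ** 2).sum())
        return d_tokvecs, loss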