mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 20:30:24 +03:00
Add rehearse method to tok2vec and link listeners
This commit is contained in:
parent
97a9c03398
commit
702de13992
|
@ -1216,7 +1216,6 @@ class Language:
|
||||||
self._optimizer = self.create_optimizer()
|
self._optimizer = self.create_optimizer()
|
||||||
sgd = self._optimizer
|
sgd = self._optimizer
|
||||||
pipes = list(self.pipeline)
|
pipes = list(self.pipeline)
|
||||||
random.shuffle(pipes)
|
|
||||||
if component_cfg is None:
|
if component_cfg is None:
|
||||||
component_cfg = {}
|
component_cfg = {}
|
||||||
grads = {}
|
grads = {}
|
||||||
|
@ -1340,6 +1339,13 @@ class Language:
|
||||||
for name, proc in self.pipeline:
|
for name, proc in self.pipeline:
|
||||||
if hasattr(proc, "_rehearsal_model"):
|
if hasattr(proc, "_rehearsal_model"):
|
||||||
proc._rehearsal_model = deepcopy(proc.model) # type: ignore[attr-defined]
|
proc._rehearsal_model = deepcopy(proc.model) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
# Relink the listeners of rehearsal models to their respective upstream tok2vec component
|
||||||
|
# Otherwise they won't be synced with the tok2vec and throw a mismatched ID error
|
||||||
|
for i, (name1, proc1) in enumerate(self.pipeline):
|
||||||
|
if isinstance(proc1, ty.ListenedToComponent):
|
||||||
|
for name2, proc2 in self.pipeline[i + 1 :]:
|
||||||
|
proc1.find_listeners(proc2)
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
self._optimizer = sgd
|
self._optimizer = sgd
|
||||||
elif self._optimizer is None:
|
elif self._optimizer is None:
|
||||||
|
|
|
@ -59,6 +59,7 @@ class Tok2Vec(TrainablePipe):
|
||||||
"""
|
"""
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self._rehearsal_model = None
|
||||||
self.name = name
|
self.name = name
|
||||||
self.listener_map: Dict[str, List["Tok2VecListener"]] = {}
|
self.listener_map: Dict[str, List["Tok2VecListener"]] = {}
|
||||||
self.cfg: Dict[str, Any] = {}
|
self.cfg: Dict[str, Any] = {}
|
||||||
|
@ -108,6 +109,11 @@ class Tok2Vec(TrainablePipe):
|
||||||
for node in component.model.walk():
|
for node in component.model.walk():
|
||||||
if isinstance(node, Tok2VecListener) and node.upstream_name in names:
|
if isinstance(node, Tok2VecListener) and node.upstream_name in names:
|
||||||
self.add_listener(node, component.name)
|
self.add_listener(node, component.name)
|
||||||
|
# Make sure to link to Tok2VecListeners from rehearsal models
|
||||||
|
if isinstance(getattr(component, "_rehearsal_model", None), Model):
|
||||||
|
for node in component._rehearsal_model.walk():
|
||||||
|
if isinstance(node, Tok2VecListener) and node.upstream_name in names:
|
||||||
|
self.add_listener(node, component.name + "_rehearsal_model")
|
||||||
|
|
||||||
def predict(self, docs: Iterable[Doc]):
|
def predict(self, docs: Iterable[Doc]):
|
||||||
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||||
|
@ -191,6 +197,67 @@ class Tok2Vec(TrainablePipe):
|
||||||
self.listeners[-1].receive(batch_id, tokvecs, backprop)
|
self.listeners[-1].receive(batch_id, tokvecs, backprop)
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
|
def rehearse(
    self,
    examples: Iterable[Example],
    *,
    drop: float = 0.0,
    sgd: Optional[Optimizer] = None,
    losses: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """Perform a "rehearsal" update from a batch of data. Rehearsal updates
    teach the current model to make predictions similar to an initial model,
    to try to address the "catastrophic forgetting" problem. This feature is
    experimental.

    The tok2vec output is handed to every listener together with a callback.
    All listeners but the last only accumulate their gradients; the last one
    additionally triggers the backward pass through the tok2vec model. If no
    rehearsal model has been set, nothing is computed and `losses` is
    returned as-is.

    examples (Iterable[Example]): A batch of Example objects.
    drop (float): The dropout rate.
    sgd (thinc.api.Optimizer): The optimizer. If None, gradients are
        backpropagated but no parameter update is applied.
    losses (Dict[str, float]): Optional record of the loss during training.
        Updated using the component name as the key.
    RETURNS (Dict[str, float]): The updated losses dictionary.

    TODO DOCS: https://spacy.io/api/tok2vec#rehearse
    """
    if losses is None:
        losses = {}
    if self._rehearsal_model is None:
        # Nothing to rehearse against.
        return losses
    # The label is what shows up in validation error messages, so it has to
    # name this method.
    validate_examples(examples, "Tok2Vec.rehearse")
    docs = [eg.predicted for eg in examples]
    set_dropout_rate(self.model, drop)
    tokvecs, bp_tokvecs = self.model.begin_update(docs)
    # TODO tokvecs vs target
    # How should the output from the rehearsal model influence the results of
    # the tok2vec model? Until that is decided, the rehearsal model's output
    # is not used here, so no forward pass is run on it.
    d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
    losses.setdefault(self.name, 0.0)

    def accumulate_gradient(one_d_tokvecs):
        """Accumulate tok2vec loss and gradient. This is passed as a callback
        to all but the last listener. Only the last one does the backprop.
        """
        # d_tokvecs is mutated in place, so no nonlocal declaration is needed.
        for i, d_t2v in enumerate(one_d_tokvecs):
            d_tokvecs[i] += d_t2v
            losses[self.name] += float((d_t2v ** 2).sum())
        # Hand back freshly allocated arrays shaped like the tok2vec output.
        return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]

    def backprop(one_d_tokvecs):
        """Callback to actually do the backprop. Passed to last listener."""
        accumulate_gradient(one_d_tokvecs)
        d_docs = bp_tokvecs(d_tokvecs)
        if sgd is not None:
            self.finish_update(sgd)
        return d_docs

    batch_id = Tok2VecListener.get_batch_id(docs)
    # Read the listeners once so the split below works on a single snapshot.
    listeners = self.listeners
    for listener in listeners[:-1]:
        listener.receive(batch_id, tokvecs, accumulate_gradient)
    if listeners:
        listeners[-1].receive(batch_id, tokvecs, backprop)
    return losses
|
||||||
|
|
||||||
def get_loss(self, examples, scores) -> None:
|
def get_loss(self, examples, scores) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user