mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Only run backprop once when shared tok2vec weights (#5331)
Previously, pipelines with shared tok2vec weights would call the tok2vec backprop callback multiple times, once for each pipeline component. This caused errors for PyTorch, and was inefficient. Instead, accumulate the gradient for all but one component, and just call the callback once.
This commit is contained in:
parent
6918d99b6c
commit
b2ef6100af
|
@ -103,20 +103,30 @@ class Tok2Vec(Pipe):
|
|||
set_dropout_rate(self.model, drop)
|
||||
tokvecs, bp_tokvecs = self.model.begin_update(docs)
|
||||
|
||||
def capture_losses(d_tokvecs):
|
||||
"""Accumulate tok2vec loss before doing backprop."""
|
||||
l2_loss = sum((d_t2v ** 2).sum() for d_t2v in d_tokvecs)
|
||||
if self.name in losses:
|
||||
losses[self.name] += l2_loss / len(d_tokvecs)
|
||||
else:
|
||||
losses[self.name] = l2_loss / len(d_tokvecs)
|
||||
return bp_tokvecs(d_tokvecs)
|
||||
d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
|
||||
losses.setdefault(self.name, 0.0)
|
||||
|
||||
def accumulate_gradient(one_d_tokvecs):
|
||||
"""Accumulate tok2vec loss and gradient. This is passed as a callback
|
||||
to all but the last listener. Only the last one does the backprop.
|
||||
"""
|
||||
nonlocal d_tokvecs
|
||||
for i in range(len(one_d_tokvecs)):
|
||||
d_tokvecs[i] += one_d_tokvecs[i]
|
||||
losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
|
||||
|
||||
def backprop(one_d_tokvecs):
|
||||
"""Callback to actually do the backprop. Passed to last listener."""
|
||||
accumulate_gradient(one_d_tokvecs)
|
||||
d_docs = bp_tokvecs(d_tokvecs)
|
||||
if sgd is not None:
|
||||
self.model.finish_update(sgd)
|
||||
return d_docs
|
||||
|
||||
batch_id = Tok2VecListener.get_batch_id(docs)
|
||||
for listener in self.listeners:
|
||||
listener.receive(batch_id, tokvecs, capture_losses)
|
||||
if sgd is not None:
|
||||
self.model.finish_update(sgd)
|
||||
for listener in self.listeners[:-1]:
|
||||
listener.receive(batch_id, tokvecs, accumulate_gradient)
|
||||
self.listeners[-1].receive(batch_id, tokvecs, backprop)
|
||||
if set_annotations:
|
||||
self.set_annotations(docs, tokvecs)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user