mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 20:52:23 +03:00
Help out python gc in coref backprop
This commit is contained in:
parent
fa92daf052
commit
8c5df622d8
|
@ -101,6 +101,7 @@ def tuplify_init(model, X, Y) -> Model:
|
|||
layer.initialize(X=X)
|
||||
return model
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpanEmbeddings:
|
||||
indices: Ints2d # Array with 2 columns (for start and end index)
|
||||
|
@ -272,20 +273,25 @@ def coarse_prune(
|
|||
|
||||
out = SpanEmbeddings(top_spans, Ragged(top_vecs, sellens))
|
||||
|
||||
# save some variables so the embeds can be garbage collected
|
||||
idxlen = spanembeds.indices.shape[0]
|
||||
vecshape = spanembeds.vectors.data.shape
|
||||
indices = spanembeds.indices
|
||||
veclens = out.vectors.lengths
|
||||
|
||||
def coarse_prune_backprop(
|
||||
dY: Tuple[Floats1d, SpanEmbeddings]
|
||||
) -> Tuple[Floats1d, SpanEmbeddings]:
|
||||
ll = spanembeds.indices.shape[0]
|
||||
|
||||
dYscores, dYembeds = dY
|
||||
|
||||
dXscores = model.ops.alloc1f(ll)
|
||||
dXscores = model.ops.alloc1f(idxlen)
|
||||
dXscores[selected] = dYscores.squeeze()
|
||||
|
||||
dXvecs = model.ops.alloc2f(*spanembeds.vectors.data.shape)
|
||||
dXvecs = model.ops.alloc2f(*vecshape)
|
||||
dXvecs[selected] = dYembeds.vectors.data
|
||||
rout = Ragged(dXvecs, out.vectors.lengths)
|
||||
dXembeds = SpanEmbeddings(spanembeds.indices, rout)
|
||||
rout = Ragged(dXvecs, veclens)
|
||||
dXembeds = SpanEmbeddings(indices, rout)
|
||||
|
||||
# inflate for mention scorer
|
||||
dXscores = model.ops.xp.expand_dims(dXscores, 1)
|
||||
|
@ -381,15 +387,20 @@ def ant_scorer_forward(
|
|||
offset += ll
|
||||
backprops.append((prod_back, pw_sum_back))
|
||||
|
||||
# save vars for gc
|
||||
vecshape = vecs.data.shape
|
||||
veclens = vecs.lengths
|
||||
scoreshape = mscores.shape
|
||||
|
||||
def backprop(
|
||||
dYs: Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d]
|
||||
) -> Tuple[Floats2d, SpanEmbeddings]:
|
||||
dYscores, dYembeds = dYs
|
||||
dXembeds = Ragged(ops.alloc2f(*vecs.data.shape), vecs.lengths)
|
||||
dXscores = ops.alloc1f(*mscores.shape)
|
||||
dXembeds = Ragged(ops.alloc2f(*vecshape), veclens)
|
||||
dXscores = ops.alloc1f(*scoreshape)
|
||||
|
||||
offset = 0
|
||||
for dy, (prod_back, pw_sum_back), ll in zip(dYscores, backprops, vecs.lengths):
|
||||
for dy, (prod_back, pw_sum_back), ll in zip(dYscores, backprops, veclens):
|
||||
# I'm not undoing the operations in the right order here.
|
||||
dyscore, dyidx = dy
|
||||
# the full score grid is square
|
||||
|
@ -427,7 +438,7 @@ def pairwise_sum(ops, mention_scores: Floats1d) -> Tuple[Floats2d, Callable]:
|
|||
out = ops.alloc1f(dim)
|
||||
for ii in range(dim):
|
||||
out[ii] = d_pwsum[:, ii].sum() + d_pwsum[ii, :].sum()
|
||||
#XXX maybe subtract d_pwsum[ii,ii] to avoid double counting?
|
||||
# XXX maybe subtract d_pwsum[ii,ii] to avoid double counting?
|
||||
|
||||
return out
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user