Help out Python's garbage collector in coref backprop

This commit is contained in:
Paul O'Leary McCann 2021-05-20 16:40:55 +09:00
parent fa92daf052
commit 8c5df622d8

View File

@ -101,6 +101,7 @@ def tuplify_init(model, X, Y) -> Model:
layer.initialize(X=X)
return model
@dataclass
class SpanEmbeddings:
indices: Ints2d # Array with 2 columns (for start and end index)
@ -272,20 +273,25 @@ def coarse_prune(
out = SpanEmbeddings(top_spans, Ragged(top_vecs, sellens))
# save some variables so the embeds can be garbage collected
idxlen = spanembeds.indices.shape[0]
vecshape = spanembeds.vectors.data.shape
indices = spanembeds.indices
veclens = out.vectors.lengths
def coarse_prune_backprop(
dY: Tuple[Floats1d, SpanEmbeddings]
) -> Tuple[Floats1d, SpanEmbeddings]:
ll = spanembeds.indices.shape[0]
dYscores, dYembeds = dY
dXscores = model.ops.alloc1f(ll)
dXscores = model.ops.alloc1f(idxlen)
dXscores[selected] = dYscores.squeeze()
dXvecs = model.ops.alloc2f(*spanembeds.vectors.data.shape)
dXvecs = model.ops.alloc2f(*vecshape)
dXvecs[selected] = dYembeds.vectors.data
rout = Ragged(dXvecs, out.vectors.lengths)
dXembeds = SpanEmbeddings(spanembeds.indices, rout)
rout = Ragged(dXvecs, veclens)
dXembeds = SpanEmbeddings(indices, rout)
# inflate for mention scorer
dXscores = model.ops.xp.expand_dims(dXscores, 1)
@ -381,15 +387,20 @@ def ant_scorer_forward(
offset += ll
backprops.append((prod_back, pw_sum_back))
# save vars for gc
vecshape = vecs.data.shape
veclens = vecs.lengths
scoreshape = mscores.shape
def backprop(
dYs: Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d]
) -> Tuple[Floats2d, SpanEmbeddings]:
dYscores, dYembeds = dYs
dXembeds = Ragged(ops.alloc2f(*vecs.data.shape), vecs.lengths)
dXscores = ops.alloc1f(*mscores.shape)
dXembeds = Ragged(ops.alloc2f(*vecshape), veclens)
dXscores = ops.alloc1f(*scoreshape)
offset = 0
for dy, (prod_back, pw_sum_back), ll in zip(dYscores, backprops, vecs.lengths):
for dy, (prod_back, pw_sum_back), ll in zip(dYscores, backprops, veclens):
# I'm not undoing the operations in the right order here.
dyscore, dyidx = dy
# the full score grid is square
@ -427,7 +438,7 @@ def pairwise_sum(ops, mention_scores: Floats1d) -> Tuple[Floats2d, Callable]:
out = ops.alloc1f(dim)
for ii in range(dim):
out[ii] = d_pwsum[:, ii].sum() + d_pwsum[ii, :].sum()
#XXX maybe subtract d_pwsum[ii,ii] to avoid double counting?
# XXX maybe subtract d_pwsum[ii,ii] to avoid double counting?
return out