Break pairwise operations into pseudolayers

This makes their scope tighter and more contained, and has the nice side
effect that fewer things need to be passed around for backprop.
Paul O'Leary McCann 2021-05-20 15:59:51 +09:00
parent d22acee4f7
commit fa92daf052
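
For context on the pattern: a "pseudolayer" here is just a plain function that computes its result and also returns a backprop closure, the same (output, backprop) contract that thinc layers follow. Because the closure captures whatever intermediates the backward pass needs, nothing extra has to be threaded through the calling code. A minimal sketch of the idea, with illustrative names not taken from this diff:

from typing import Callable, Tuple
import numpy as np

def square_pseudolayer(x: np.ndarray) -> Tuple[np.ndarray, Callable]:
    # forward pass: compute the result
    y = x * x

    def backward(d_y: np.ndarray) -> np.ndarray:
        # the closure captures x, so the caller doesn't have to keep it
        # around just to run the backward pass
        return 2 * x * d_y

    return y, backward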


@@ -88,7 +88,8 @@ def tuplify_forward(model, X, is_train):
     return tuple(Ys), backprop_tuplify
-#TODO make more robust, see chain
+# TODO make more robust, see chain
 def tuplify_init(model, X, Y) -> Model:
     if X is None and Y is None:
         for layer in model.layers:
@@ -149,7 +150,7 @@ def span_embeddings_forward(
     tokvecs, docs = inputs
-    #TODO fix this
+    # TODO fix this
     dim = tokvecs[0].shape[1]
     get_mentions = model.attrs["get_mentions"]
@@ -157,6 +158,7 @@ def span_embeddings_forward(
     mentions = ops.alloc2i(0, 2)
     total_length = 0
+    docmenlens = []  # number of mentions per doc
     for doc in docs:
         starts, ends = get_mentions(doc, max_span_width)
+        docmenlens.append(len(starts))
@@ -350,13 +352,11 @@ def ant_scorer_forward(
         # first calculate the pairwise product scores
         cvecs = vecs.data[offset:hi]
-        source, source_b = bilinear(cvecs, is_train)
-        target, target_b = dropout(cvecs, is_train)
-        pw_prod = xp.matmul(source, target.T)
+        pw_prod, prod_back = pairwise_product(bilinear, dropout, cvecs, is_train)
         # now calculate the pairwise mention scores
         ms = mscores[offset:hi].squeeze()
-        pw_sum = xp.expand_dims(ms, 1) + xp.expand_dims(ms, 0)
+        pw_sum, pw_sum_back = pairwise_sum(ops, ms)
         # make a mask so antecedents precede referrents
         ant_range = xp.arange(0, cvecs.shape[0])
@@ -379,7 +379,7 @@ def ant_scorer_forward(
         # garbage collected when the loop exits).
         offset += ll
-        backprops.append((source_b, target_b, source, target))
+        backprops.append((prod_back, pw_sum_back))
     def backprop(
         dYs: Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d]
@@ -389,9 +389,7 @@ def ant_scorer_forward(
         dXscores = ops.alloc1f(*mscores.shape)
         offset = 0
-        for dy, (source_b, target_b, source, target), ll in zip(
-            dYscores, backprops, vecs.lengths
-        ):
+        for dy, (prod_back, pw_sum_back), ll in zip(dYscores, backprops, vecs.lengths):
             # I'm not undoing the operations in the right order here.
             dyscore, dyidx = dy
             # the full score grid is square
@@ -402,18 +400,51 @@ def ant_scorer_forward(
             for ii, (ridx, rscores) in enumerate(zip(dyidx, dyscore)):
                 fullscore[ii][ridx] = rscores
-            dS = source_b(fullscore @ target)
-            dT = target_b(fullscore @ source)
-            dXembeds.data[offset : offset + ll] = dS + dT
+            dXembeds.data[offset : offset + ll] = prod_back(fullscore)
+            dXscores[offset : offset + ll] = pw_sum_back(fullscore)
-            # The gradient can be distributed over all the rows and columns here,
-            # so aggregate it
-            section = dXscores[offset : offset + ll]
-            for ii in range(ll):
-                section[ii] = fullscore[:, ii].sum() + fullscore[ii, :].sum()
             offset += ll
         # make it fit back into the linear
         dXscores = xp.expand_dims(dXscores, 1)
         return (dXscores, SpanEmbeddings(sembeds.indices, dXembeds))
     return (out, sembeds.indices), backprop
+
+def pairwise_sum(ops, mention_scores: Floats1d) -> Tuple[Floats2d, Callable]:
+    """Find the most likely mention-antecedent pairs."""
+    # This doesn't use multiplication because two items with low mention scores
+    # don't make a good candidate pair.
+    pw_sum = ops.xp.expand_dims(mention_scores, 1) + ops.xp.expand_dims(
+        mention_scores, 0
+    )
+    def backward(d_pwsum: Floats2d) -> Floats1d:
+        # For the backward pass, the gradient is distributed over the whole row and
+        # column, so pull it all in.
+        dim = d_pwsum.shape[0]
+        out = ops.alloc1f(dim)
+        for ii in range(dim):
+            out[ii] = d_pwsum[:, ii].sum() + d_pwsum[ii, :].sum()
+            #XXX maybe subtract d_pwsum[ii,ii] to avoid double counting?
+        return out
+    return pw_sum, backward
+
+def pairwise_product(bilinear, dropout, vecs: Floats2d, is_train):
+    # A neat side effect of this is that we don't have to pass the backprops
+    # around separately because the closure handles them.
+    source, source_b = bilinear(vecs, is_train)
+    target, target_b = dropout(vecs, is_train)
+    pw_prod = bilinear.ops.xp.matmul(source, target.T)
+    def backward(d_prod: Floats2d) -> Floats2d:
+        dS = source_b(d_prod @ target)
+        dT = target_b(d_prod @ source)
+        dX = dS + dT
+        return dX
+    return pw_prod, backward
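
As a rough, self-contained illustration of what pairwise_sum and its backward closure compute, here is the same logic written against plain numpy instead of thinc's ops (the demo function name and the all-ones upstream gradient are illustrative only, not part of the diff):

import numpy as np

def pairwise_sum_demo(mention_scores: np.ndarray):
    # forward: pair score (i, j) is mention_scores[i] + mention_scores[j]
    pw_sum = np.expand_dims(mention_scores, 1) + np.expand_dims(mention_scores, 0)

    def backward(d_pwsum: np.ndarray) -> np.ndarray:
        # score i appears in every cell of row i and column i of the grid,
        # so its gradient is the sum of both (the diagonal cell is counted
        # twice, which is what the #XXX note above is about)
        return d_pwsum.sum(axis=1) + d_pwsum.sum(axis=0)

    return pw_sum, backward

scores = np.array([0.1, 0.5, -0.2], dtype="float32")
pw, back = pairwise_sum_demo(scores)
print(pw.shape)                 # (3, 3)
print(back(np.ones_like(pw)))   # [6. 6. 6.]: one full row plus one full column each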