Clean up span embedding code

This is now cleaner and significantly faster. There are still some messy
parts in the code (particularly variable names); I will get to those later.
This commit is contained in:
Paul O'Leary McCann 2021-07-10 19:59:08 +09:00
parent dc1f974d39
commit d7d317a1b5

View File

@ -118,7 +118,10 @@ def build_span_embedder(
max_span_width: int = 20, max_span_width: int = 20,
) -> Model[Tuple[List[Floats2d], List[Doc]], SpanEmbeddings]: ) -> Model[Tuple[List[Floats2d], List[Doc]], SpanEmbeddings]:
return Model( with Model.define_operators({">>": chain, "|": concatenate}):
span_reduce = (extract_spans() >>
(reduce_first() | reduce_last() | reduce_mean()))
model = Model(
"SpanEmbedding", "SpanEmbedding",
forward=span_embeddings_forward, forward=span_embeddings_forward,
attrs={ attrs={
@ -127,7 +130,10 @@ def build_span_embedder(
# mention generator # mention generator
"max_span_width": max_span_width, "max_span_width": max_span_width,
}, },
layers=[span_reduce],
) )
model.set_ref("span_reducer", span_reduce)
return model
def span_embeddings_forward( def span_embeddings_forward(
@ -157,45 +163,26 @@ def span_embeddings_forward(
# TODO support attention here # TODO support attention here
tokvecs = xp.concatenate(tokvecs) tokvecs = xp.concatenate(tokvecs)
spans = [tokvecs[ii:jj] for ii, jj in mentions] tokvecs_r = Ragged(tokvecs, docmenlens)
avgs = [xp.mean(ss, axis=0) for ss in spans] mentions_r = Ragged(mentions, docmenlens)
spanvecs = ops.asarray2f(avgs)
# first and last token embeds span_reduce = model.get_ref("span_reducer")
starts, ends = zip(*[(tokvecs[ii], tokvecs[jj]) for ii, jj in mentions]) spanvecs, span_reduce_back = span_reduce( (tokvecs_r, mentions_r), is_train)
starts = ops.asarray2f(starts) embeds = Ragged(spanvecs, docmenlens)
ends = ops.asarray2f(ends)
concat = xp.concatenate((starts, ends, spanvecs), 1)
embeds = Ragged(concat, docmenlens)
def backprop_span_embed(dY: SpanEmbeddings) -> Tuple[List[Floats2d], List[Doc]]: def backprop_span_embed(dY: SpanEmbeddings) -> Tuple[List[Floats2d], List[Doc]]:
oweights = [] oweights = []
odocs = []
offset = 0 offset = 0
tokoffset = 0 for mlen in dY.vectors.lengths:
for indoc, mlen in zip(docs, dY.vectors.lengths):
hi = offset + mlen hi = offset + mlen
hitok = tokoffset + len(indoc)
odocs.append(indoc) # no change
vecs = dY.vectors.data[offset:hi] vecs = dY.vectors.data[offset:hi]
out, out_idx = span_reduce_back(vecs)
starts = vecs[:, :dim] oweights.append(out.data)
ends = vecs[:, dim : 2 * dim]
spanvecs = vecs[:, 2 * dim :]
out = model.ops.alloc2f(len(indoc), dim)
idxs = dY.indices[offset:hi] - tokoffset
ops.scatter_add(out, idxs[:, 0], starts)
ops.scatter_add(out, idxs[:, 1], ends)
ops.scatter_add(out, idxs.T, spanvecs)
oweights.append(out)
offset = hi offset = hi
tokoffset = hitok return oweights, docs
return oweights, odocs
return SpanEmbeddings(mentions, embeds), backprop_span_embed return SpanEmbeddings(mentions, embeds), backprop_span_embed