From d7d317a1b5fcd250828db3b486500aa5db580478 Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Sat, 10 Jul 2021 19:59:08 +0900 Subject: [PATCH] Clean up span embedding code This is now cleaner and significantly faster. There are still some messy parts in the code (particularly variable names); I will get to that later. --- spacy/ml/models/coref.py | 45 ++++++++++++++-------------------------- 1 file changed, 16 insertions(+), 29 deletions(-) diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index 66039564e..37f6ff0ff 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -118,7 +118,10 @@ def build_span_embedder( max_span_width: int = 20, ) -> Model[Tuple[List[Floats2d], List[Doc]], SpanEmbeddings]: - return Model( + with Model.define_operators({">>": chain, "|": concatenate}): + span_reduce = (extract_spans() >> + (reduce_first() | reduce_last() | reduce_mean())) + model = Model( "SpanEmbedding", forward=span_embeddings_forward, attrs={ @@ -127,7 +130,10 @@ def build_span_embedder( # mention generator "max_span_width": max_span_width, }, + layers=[span_reduce], ) + model.set_ref("span_reducer", span_reduce) + return model def span_embeddings_forward( @@ -157,45 +163,26 @@ def span_embeddings_forward( # TODO support attention here tokvecs = xp.concatenate(tokvecs) - spans = [tokvecs[ii:jj] for ii, jj in mentions] - avgs = [xp.mean(ss, axis=0) for ss in spans] - spanvecs = ops.asarray2f(avgs) + tokvecs_r = Ragged(tokvecs, docmenlens) + mentions_r = Ragged(mentions, docmenlens) - # first and last token embeds - starts, ends = zip(*[(tokvecs[ii], tokvecs[jj]) for ii, jj in mentions]) + span_reduce = model.get_ref("span_reducer") + spanvecs, span_reduce_back = span_reduce( (tokvecs_r, mentions_r), is_train) - starts = ops.asarray2f(starts) - ends = ops.asarray2f(ends) - concat = xp.concatenate((starts, ends, spanvecs), 1) - embeds = Ragged(concat, docmenlens) + embeds = Ragged(spanvecs, docmenlens) def backprop_span_embed(dY: 
SpanEmbeddings) -> Tuple[List[Floats2d], List[Doc]]: oweights = [] - odocs = [] offset = 0 - tokoffset = 0 - for indoc, mlen in zip(docs, dY.vectors.lengths): + for mlen in dY.vectors.lengths: hi = offset + mlen - hitok = tokoffset + len(indoc) - odocs.append(indoc) # no change vecs = dY.vectors.data[offset:hi] - - starts = vecs[:, :dim] - ends = vecs[:, dim : 2 * dim] - spanvecs = vecs[:, 2 * dim :] - - out = model.ops.alloc2f(len(indoc), dim) - - idxs = dY.indices[offset:hi] - tokoffset - ops.scatter_add(out, idxs[:, 0], starts) - ops.scatter_add(out, idxs[:, 1], ends) - ops.scatter_add(out, idxs.T, spanvecs) - oweights.append(out) + out, out_idx = span_reduce_back(vecs) + oweights.append(out.data) offset = hi - tokoffset = hitok - return oweights, odocs + return oweights, docs return SpanEmbeddings(mentions, embeds), backprop_span_embed