mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 20:52:23 +03:00
Clean up span embedding code
This is now cleaner and significantly faster. There are still some messy parts in the code (particularly variable names); I will get to those later.
This commit is contained in:
parent
dc1f974d39
commit
d7d317a1b5
|
@ -118,7 +118,10 @@ def build_span_embedder(
|
||||||
max_span_width: int = 20,
|
max_span_width: int = 20,
|
||||||
) -> Model[Tuple[List[Floats2d], List[Doc]], SpanEmbeddings]:
|
) -> Model[Tuple[List[Floats2d], List[Doc]], SpanEmbeddings]:
|
||||||
|
|
||||||
return Model(
|
with Model.define_operators({">>": chain, "|": concatenate}):
|
||||||
|
span_reduce = (extract_spans() >>
|
||||||
|
(reduce_first() | reduce_last() | reduce_mean()))
|
||||||
|
model = Model(
|
||||||
"SpanEmbedding",
|
"SpanEmbedding",
|
||||||
forward=span_embeddings_forward,
|
forward=span_embeddings_forward,
|
||||||
attrs={
|
attrs={
|
||||||
|
@ -127,7 +130,10 @@ def build_span_embedder(
|
||||||
# mention generator
|
# mention generator
|
||||||
"max_span_width": max_span_width,
|
"max_span_width": max_span_width,
|
||||||
},
|
},
|
||||||
|
layers=[span_reduce],
|
||||||
)
|
)
|
||||||
|
model.set_ref("span_reducer", span_reduce)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def span_embeddings_forward(
|
def span_embeddings_forward(
|
||||||
|
@ -157,45 +163,26 @@ def span_embeddings_forward(
|
||||||
|
|
||||||
# TODO support attention here
|
# TODO support attention here
|
||||||
tokvecs = xp.concatenate(tokvecs)
|
tokvecs = xp.concatenate(tokvecs)
|
||||||
spans = [tokvecs[ii:jj] for ii, jj in mentions]
|
tokvecs_r = Ragged(tokvecs, docmenlens)
|
||||||
avgs = [xp.mean(ss, axis=0) for ss in spans]
|
mentions_r = Ragged(mentions, docmenlens)
|
||||||
spanvecs = ops.asarray2f(avgs)
|
|
||||||
|
|
||||||
# first and last token embeds
|
span_reduce = model.get_ref("span_reducer")
|
||||||
starts, ends = zip(*[(tokvecs[ii], tokvecs[jj]) for ii, jj in mentions])
|
spanvecs, span_reduce_back = span_reduce( (tokvecs_r, mentions_r), is_train)
|
||||||
|
|
||||||
starts = ops.asarray2f(starts)
|
embeds = Ragged(spanvecs, docmenlens)
|
||||||
ends = ops.asarray2f(ends)
|
|
||||||
concat = xp.concatenate((starts, ends, spanvecs), 1)
|
|
||||||
embeds = Ragged(concat, docmenlens)
|
|
||||||
|
|
||||||
def backprop_span_embed(dY: SpanEmbeddings) -> Tuple[List[Floats2d], List[Doc]]:
|
def backprop_span_embed(dY: SpanEmbeddings) -> Tuple[List[Floats2d], List[Doc]]:
|
||||||
|
|
||||||
oweights = []
|
oweights = []
|
||||||
odocs = []
|
|
||||||
offset = 0
|
offset = 0
|
||||||
tokoffset = 0
|
for mlen in dY.vectors.lengths:
|
||||||
for indoc, mlen in zip(docs, dY.vectors.lengths):
|
|
||||||
hi = offset + mlen
|
hi = offset + mlen
|
||||||
hitok = tokoffset + len(indoc)
|
|
||||||
odocs.append(indoc) # no change
|
|
||||||
vecs = dY.vectors.data[offset:hi]
|
vecs = dY.vectors.data[offset:hi]
|
||||||
|
out, out_idx = span_reduce_back(vecs)
|
||||||
starts = vecs[:, :dim]
|
oweights.append(out.data)
|
||||||
ends = vecs[:, dim : 2 * dim]
|
|
||||||
spanvecs = vecs[:, 2 * dim :]
|
|
||||||
|
|
||||||
out = model.ops.alloc2f(len(indoc), dim)
|
|
||||||
|
|
||||||
idxs = dY.indices[offset:hi] - tokoffset
|
|
||||||
ops.scatter_add(out, idxs[:, 0], starts)
|
|
||||||
ops.scatter_add(out, idxs[:, 1], ends)
|
|
||||||
ops.scatter_add(out, idxs.T, spanvecs)
|
|
||||||
oweights.append(out)
|
|
||||||
|
|
||||||
offset = hi
|
offset = hi
|
||||||
tokoffset = hitok
|
return oweights, docs
|
||||||
return oweights, odocs
|
|
||||||
|
|
||||||
return SpanEmbeddings(mentions, embeds), backprop_span_embed
|
return SpanEmbeddings(mentions, embeds), backprop_span_embed
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user