Use scatter_add to speed up span embed backprop

This was the slowest part of the code, and using scatter_add here
probably reduces the runtime by 50%.
This commit is contained in:
Paul O'Leary McCann 2021-07-10 18:08:51 +09:00
parent d0b041aff4
commit f34915c1e8

View File

@ -187,14 +187,10 @@ def span_embeddings_forward(
out = model.ops.alloc2f(len(indoc), dim) out = model.ops.alloc2f(len(indoc), dim)
for ii, (start, end) in enumerate(dY.indices[offset:hi]): idxs = dY.indices[offset:hi] - tokoffset
# adjust indexes to align with doc ops.scatter_add(out, idxs[:, 0], starts)
start -= tokoffset ops.scatter_add(out, idxs[:, 1], ends)
end -= tokoffset ops.scatter_add(out, idxs.T, spanvecs)
out[start] += starts[ii]
out[end] += ends[ii]
out[start:end] += spanvecs[ii]
oweights.append(out) oweights.append(out)
offset = hi offset = hi