diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index 719750ecb..66039564e 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -187,14 +187,10 @@ def span_embeddings_forward( out = model.ops.alloc2f(len(indoc), dim) - for ii, (start, end) in enumerate(dY.indices[offset:hi]): - # adjust indexes to align with doc - start -= tokoffset - end -= tokoffset - - out[start] += starts[ii] - out[end] += ends[ii] - out[start:end] += spanvecs[ii] + idxs = dY.indices[offset:hi] - tokoffset + ops.scatter_add(out, idxs[:, 0], starts) + ops.scatter_add(out, idxs[:, 1], ends) + ops.scatter_add(out, idxs.T, spanvecs) oweights.append(out) offset = hi