diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index f83014344..31643d248 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -356,9 +356,11 @@ def build_take_vecs() -> Model[SpanEmbeddings, Floats2d]: def take_vecs_forward(model, inputs: SpanEmbeddings, is_train) -> Floats2d: + idxs = inputs.indices + lens = inputs.vectors.lengths def backprop(dY: Floats2d) -> SpanEmbeddings: - vecs = Ragged(dY, inputs.vectors.lengths) - return SpanEmbeddings(inputs.indices, vecs) + vecs = Ragged(dY, lens) + return SpanEmbeddings(idxs, vecs) return inputs.vectors.data, backprop