From 0c7c503a904ffb313c6fb14caf91e2c2fbbe3571 Mon Sep 17 00:00:00 2001 From: kadarakos Date: Wed, 26 Apr 2023 19:10:23 +0000 Subject: [PATCH] avoid nesting then flattening --- spacy/ml/extract_spans.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spacy/ml/extract_spans.py b/spacy/ml/extract_spans.py index d5e9bc07c..ec674115b 100644 --- a/spacy/ml/extract_spans.py +++ b/spacy/ml/extract_spans.py @@ -57,9 +57,9 @@ def _get_span_indices(ops, spans: Ragged, lengths: Ints1d) -> Ints1d: for i, length in enumerate(lengths): spans_i = spans[i].dataXd + offset for j in range(spans_i.shape[0]): - indices.append(ops.xp.arange(spans_i[j, 0], spans_i[j, 1])) # type: ignore[call-overload, index] + indices += list(range(spans_i[j, 0], spans_i[j, 1])) # type: ignore[call-overload, index] offset += length - return ops.flatten(indices, dtype="i", ndim_if_empty=1) + return ops.xp.asarray(indices, dtype="int") def _ensure_cpu(spans: Ragged, lengths: Ints1d) -> Tuple[Ragged, Ints1d]: