diff --git a/spacy/ml/extract_spans.py b/spacy/ml/extract_spans.py index d5e9bc07c..af6be78db 100644 --- a/spacy/ml/extract_spans.py +++ b/spacy/ml/extract_spans.py @@ -1,4 +1,4 @@ -from typing import Tuple, Callable +from typing import List, Tuple, Callable from thinc.api import Model, to_numpy from thinc.types import Ragged, Ints1d @@ -52,14 +52,14 @@ def _get_span_indices(ops, spans: Ragged, lengths: Ints1d) -> Ints1d: indices will be [5, 6, 7, 8, 8, 9]. """ spans, lengths = _ensure_cpu(spans, lengths) - indices = [] + indices: List[int] = [] offset = 0 for i, length in enumerate(lengths): spans_i = spans[i].dataXd + offset for j in range(spans_i.shape[0]): - indices.append(ops.xp.arange(spans_i[j, 0], spans_i[j, 1])) # type: ignore[call-overload, index] + indices.extend(range(spans_i[j, 0], spans_i[j, 1])) # type: ignore[arg-type, call-overload] offset += length - return ops.flatten(indices, dtype="i", ndim_if_empty=1) + return ops.asarray1i(indices) def _ensure_cpu(spans: Ragged, lengths: Ints1d) -> Tuple[Ragged, Ints1d]: