Spancat speed improvement (#12577)

* avoid nesting then flattening

* mypy fix

* Apply suggestions from code review

* Add type for indices

* Run full matrix for mypy

* Add back modified type: ignore

* Revert "Run full matrix for mypy"

This reverts commit e218873d04.

---------

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
kadarakos 2023-04-27 15:27:13 +02:00 committed by GitHub
parent a8dfc66135
commit 34d1164b0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,4 @@
from typing import Tuple, Callable from typing import List, Tuple, Callable
from thinc.api import Model, to_numpy from thinc.api import Model, to_numpy
from thinc.types import Ragged, Ints1d from thinc.types import Ragged, Ints1d
@ -52,14 +52,14 @@ def _get_span_indices(ops, spans: Ragged, lengths: Ints1d) -> Ints1d:
indices will be [5, 6, 7, 8, 8, 9]. indices will be [5, 6, 7, 8, 8, 9].
""" """
spans, lengths = _ensure_cpu(spans, lengths) spans, lengths = _ensure_cpu(spans, lengths)
indices = [] indices: List[int] = []
offset = 0 offset = 0
for i, length in enumerate(lengths): for i, length in enumerate(lengths):
spans_i = spans[i].dataXd + offset spans_i = spans[i].dataXd + offset
for j in range(spans_i.shape[0]): for j in range(spans_i.shape[0]):
indices.append(ops.xp.arange(spans_i[j, 0], spans_i[j, 1])) # type: ignore[call-overload, index] indices.extend(range(spans_i[j, 0], spans_i[j, 1])) # type: ignore[arg-type, call-overload]
offset += length offset += length
return ops.flatten(indices, dtype="i", ndim_if_empty=1) return ops.asarray1i(indices)
def _ensure_cpu(spans: Ragged, lengths: Ints1d) -> Tuple[Ragged, Ints1d]: def _ensure_cpu(spans: Ragged, lengths: Ints1d) -> Tuple[Ragged, Ints1d]: