mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Spancat speed improvement (#12577)
* avoid nesting then flattening
* mypy fix
* Apply suggestions from code review
* Add type for indices
* Run full matrix for mypy
* Add back modified type: ignore
* Revert "Run full matrix for mypy"
This reverts commit e218873d04
.
---------
Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
a8dfc66135
commit
34d1164b0e
|
@ -1,4 +1,4 @@
|
|||
from typing import Tuple, Callable
|
||||
from typing import List, Tuple, Callable
|
||||
from thinc.api import Model, to_numpy
|
||||
from thinc.types import Ragged, Ints1d
|
||||
|
||||
|
@ -52,14 +52,14 @@ def _get_span_indices(ops, spans: Ragged, lengths: Ints1d) -> Ints1d:
|
|||
indices will be [5, 6, 7, 8, 8, 9].
|
||||
"""
|
||||
spans, lengths = _ensure_cpu(spans, lengths)
|
||||
indices = []
|
||||
indices: List[int] = []
|
||||
offset = 0
|
||||
for i, length in enumerate(lengths):
|
||||
spans_i = spans[i].dataXd + offset
|
||||
for j in range(spans_i.shape[0]):
|
||||
indices.append(ops.xp.arange(spans_i[j, 0], spans_i[j, 1])) # type: ignore[call-overload, index]
|
||||
indices.extend(range(spans_i[j, 0], spans_i[j, 1])) # type: ignore[arg-type, call-overload]
|
||||
offset += length
|
||||
return ops.flatten(indices, dtype="i", ndim_if_empty=1)
|
||||
return ops.asarray1i(indices)
|
||||
|
||||
|
||||
def _ensure_cpu(spans: Ragged, lengths: Ints1d) -> Tuple[Ragged, Ints1d]:
|
||||
|
|
Loading…
Reference in New Issue
Block a user