spaCy/spacy/training/alignment_array.pyx
Daniël de Kok c90dd6f265
Alignment: use a simplified ragged type for performance (#10319)
* Alignment: use a simplified ragged type for performance

This introduces the AlignmentArray type, which is a simplified version
of Ragged that performs better on the simple(r) indexing performed for
alignment.

* AlignmentArray: raise an error when using unsupported index

* AlignmentArray: move error messages to Errors

* AlignmentArray: remove simlified ... with simplifications

* AlignmentArray: fix typo that broke a[n:n] indexing
2022-04-01 09:02:06 +02:00

69 lines
1.9 KiB
Cython

from typing import List
from ..errors import Errors
import numpy
cdef class AlignmentArray:
    """Ragged array of ints stored as one flat numpy array plus row offsets.

    AlignmentArray is similar to Thinc's Ragged with two simplifications:
    indexing returns numpy arrays and this type can only be used for CPU
    arrays. However, these changes make AlignmentArray more efficient for
    indexing in a tight loop.

    The values of row `i` are `data[starts[i]:ends[i]]`.
    """

    __slots__ = []

    def __init__(self, alignment: List[List[int]]):
        """Flatten `alignment` into a data array and an offsets array.

        alignment (List[List[int]]): The rows to store. Rows may differ in
            length and may be empty.
        """
        cdef int data_len = 0
        cdef int outer_len
        cdef int idx

        # Filled in lazily by the `lengths` property.
        self._lengths = None

        # One more entry than there are rows: entry `i` is the offset where
        # row `i` starts and entry `i + 1` is the offset where it ends.
        self._starts_ends = numpy.zeros(len(alignment) + 1, dtype="i")
        for idx, outer in enumerate(alignment):
            outer_len = len(outer)
            # The running total is both the number of values seen so far and
            # the end offset of the current row, so there is no need to read
            # the previous offset back out of the array.
            data_len += outer_len
            self._starts_ends[idx + 1] = data_len

        # Copy all values, row after row, into a single flat array.
        self._data = numpy.empty(data_len, dtype="i")
        idx = 0
        for outer in alignment:
            for inner in outer:
                self._data[idx] = inner
                idx += 1

    def __getitem__(self, idx):
        """Return the values of one row or of a contiguous range of rows.

        idx (int or slice): A row index (a Python int or a numpy integer), or
            a slice whose step is None or 1.
        RETURNS (numpy.ndarray): A slice of the flat data array holding the
            values of the selected row(s), in order. Empty if the slice
            selects no rows.
        RAISES (ValueError): If the slice has a step other than 1, or if
            `idx` is neither an integer nor a slice.
        """
        starts = self._starts_ends[:-1]
        ends = self._starts_ends[1:]
        # numpy integers are accepted alongside Python ints so that indices
        # produced by numpy computations do not have to be converted first.
        if isinstance(idx, (int, numpy.integer)):
            start = starts[idx]
            end = ends[idx]
        elif isinstance(idx, slice):
            if not (idx.step is None or idx.step == 1):
                raise ValueError(Errors.E1027)
            selected_starts = starts[idx]
            if len(selected_starts) == 0:
                # The slice selects no rows (e.g. a[n:n]).
                return self._data[0:0]
            # Rows are stored back to back, so the range spans from the start
            # of the first selected row to the end of the last one.
            start = selected_starts[0]
            end = ends[idx][-1]
        else:
            raise ValueError(Errors.E1028)

        return self._data[start:end]

    @property
    def data(self):
        """RETURNS (numpy.ndarray): The flat array holding all values."""
        return self._data

    @property
    def lengths(self):
        """RETURNS (numpy.ndarray): The number of values in each row.

        Computed as `ends - starts` on first access and cached afterwards.
        """
        if self._lengths is None:
            self._lengths = self.ends - self.starts
        return self._lengths

    @property
    def ends(self):
        """RETURNS (numpy.ndarray): The end offset (exclusive) of each row."""
        return self._starts_ends[1:]

    @property
    def starts(self):
        """RETURNS (numpy.ndarray): The start offset of each row."""
        return self._starts_ends[:-1]