Make doc.from_array several times faster

This commit is contained in:
Matthew Honnibal 2020-06-20 20:17:13 +02:00
parent fa86aa581d
commit 6d821b2e55

View File

@@ -806,12 +806,14 @@ cdef class Doc:
if SENT_START in attrs and HEAD in attrs:
raise ValueError(Errors.E032)
cdef int i, col, abs_head_index
cdef int i, col
cdef int32_t abs_head_index
cdef attr_id_t attr_id
cdef TokenC* tokens = self.c
cdef int length = len(array)
if length != len(self):
raise ValueError("Cannot set array values longer than the document.")
# Get set up for fast loading
cdef Pool mem = Pool()
cdef int n_attrs = len(attrs)
@@ -822,33 +824,52 @@ cdef class Doc:
attr_ids[i] = attr_id
if len(array.shape) == 1:
array = array.reshape((array.size, 1))
cdef np.ndarray transposed_array = numpy.ascontiguousarray(array.T)
values = <const uint64_t*>transposed_array.data
stride = transposed_array.shape[1]
# Check that all heads are within the document bounds
if HEAD in attrs:
col = attrs.index(HEAD)
for i in range(length):
# cast index to signed int
abs_head_index = numpy.int32(array[i, col]) + i
abs_head_index = <int32_t>values[col * stride + i]
abs_head_index += i
if abs_head_index < 0 or abs_head_index >= length:
raise ValueError(Errors.E190.format(index=i, value=array[i, col], rel_head_index=numpy.int32(array[i, col])))
raise ValueError(
Errors.E190.format(
index=i,
value=array[i, col],
rel_head_index=abs_head_index-i
)
)
# Do TAG first. This lets subsequent loop override stuff like POS, LEMMA
if TAG in attrs:
col = attrs.index(TAG)
for i in range(length):
if array[i, col] != 0:
self.vocab.morphology.assign_tag(&tokens[i], array[i, col])
value = values[col * stride + i]
if value != 0:
self.vocab.morphology.assign_tag(&tokens[i], value)
# Verify ENT_IOB are proper integers
if ENT_IOB in attrs:
iob_strings = Token.iob_strings()
col = attrs.index(ENT_IOB)
n_iob_strings = len(iob_strings)
for i in range(length):
if array[i, col] not in range(0, len(iob_strings)):
raise ValueError(Errors.E982.format(values=iob_strings, value=array[i, col]))
value = values[col * stride + i]
if value < 0 or value >= n_iob_strings:
raise ValueError(
Errors.E982.format(
values=iob_strings,
value=value
)
)
# Now load the data
for i in range(length):
token = &self.c[i]
for j in range(n_attrs):
if attr_ids[j] != TAG:
Token.set_struct_attr(token, attr_ids[j], array[i, j])
value = values[j * stride + i]
Token.set_struct_attr(token, attr_ids[j], value)
# Set flags
self.is_parsed = bool(self.is_parsed or HEAD in attrs)
self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs)