mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 09:00:36 +03:00
Make doc.from_array several times faster
This commit is contained in:
parent
fa86aa581d
commit
6d821b2e55
|
@ -806,12 +806,14 @@ cdef class Doc:
|
||||||
|
|
||||||
if SENT_START in attrs and HEAD in attrs:
|
if SENT_START in attrs and HEAD in attrs:
|
||||||
raise ValueError(Errors.E032)
|
raise ValueError(Errors.E032)
|
||||||
cdef int i, col, abs_head_index
|
cdef int i, col
|
||||||
|
cdef int32_t abs_head_index
|
||||||
cdef attr_id_t attr_id
|
cdef attr_id_t attr_id
|
||||||
cdef TokenC* tokens = self.c
|
cdef TokenC* tokens = self.c
|
||||||
cdef int length = len(array)
|
cdef int length = len(array)
|
||||||
if length != len(self):
|
if length != len(self):
|
||||||
raise ValueError("Cannot set array values longer than the document.")
|
raise ValueError("Cannot set array values longer than the document.")
|
||||||
|
|
||||||
# Get set up for fast loading
|
# Get set up for fast loading
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
cdef int n_attrs = len(attrs)
|
cdef int n_attrs = len(attrs)
|
||||||
|
@ -822,33 +824,52 @@ cdef class Doc:
|
||||||
attr_ids[i] = attr_id
|
attr_ids[i] = attr_id
|
||||||
if len(array.shape) == 1:
|
if len(array.shape) == 1:
|
||||||
array = array.reshape((array.size, 1))
|
array = array.reshape((array.size, 1))
|
||||||
|
cdef np.ndarray transposed_array = numpy.ascontiguousarray(array.T)
|
||||||
|
values = <const uint64_t*>transposed_array.data
|
||||||
|
stride = transposed_array.shape[1]
|
||||||
# Check that all heads are within the document bounds
|
# Check that all heads are within the document bounds
|
||||||
if HEAD in attrs:
|
if HEAD in attrs:
|
||||||
col = attrs.index(HEAD)
|
col = attrs.index(HEAD)
|
||||||
for i in range(length):
|
for i in range(length):
|
||||||
# cast index to signed int
|
# cast index to signed int
|
||||||
abs_head_index = numpy.int32(array[i, col]) + i
|
abs_head_index = <int32_t>values[col * stride + i]
|
||||||
|
abs_head_index += i
|
||||||
if abs_head_index < 0 or abs_head_index >= length:
|
if abs_head_index < 0 or abs_head_index >= length:
|
||||||
raise ValueError(Errors.E190.format(index=i, value=array[i, col], rel_head_index=numpy.int32(array[i, col])))
|
raise ValueError(
|
||||||
|
Errors.E190.format(
|
||||||
|
index=i,
|
||||||
|
value=array[i, col],
|
||||||
|
rel_head_index=abs_head_index-i
|
||||||
|
)
|
||||||
|
)
|
||||||
# Do TAG first. This lets subsequent loop override stuff like POS, LEMMA
|
# Do TAG first. This lets subsequent loop override stuff like POS, LEMMA
|
||||||
if TAG in attrs:
|
if TAG in attrs:
|
||||||
col = attrs.index(TAG)
|
col = attrs.index(TAG)
|
||||||
for i in range(length):
|
for i in range(length):
|
||||||
if array[i, col] != 0:
|
value = values[col * stride + i]
|
||||||
self.vocab.morphology.assign_tag(&tokens[i], array[i, col])
|
if value != 0:
|
||||||
|
self.vocab.morphology.assign_tag(&tokens[i], value)
|
||||||
# Verify ENT_IOB are proper integers
|
# Verify ENT_IOB are proper integers
|
||||||
if ENT_IOB in attrs:
|
if ENT_IOB in attrs:
|
||||||
iob_strings = Token.iob_strings()
|
iob_strings = Token.iob_strings()
|
||||||
col = attrs.index(ENT_IOB)
|
col = attrs.index(ENT_IOB)
|
||||||
|
n_iob_strings = len(iob_strings)
|
||||||
for i in range(length):
|
for i in range(length):
|
||||||
if array[i, col] not in range(0, len(iob_strings)):
|
value = values[col * stride + i]
|
||||||
raise ValueError(Errors.E982.format(values=iob_strings, value=array[i, col]))
|
if value < 0 or value >= n_iob_strings:
|
||||||
|
raise ValueError(
|
||||||
|
Errors.E982.format(
|
||||||
|
values=iob_strings,
|
||||||
|
value=value
|
||||||
|
)
|
||||||
|
)
|
||||||
# Now load the data
|
# Now load the data
|
||||||
for i in range(length):
|
for i in range(length):
|
||||||
token = &self.c[i]
|
token = &self.c[i]
|
||||||
for j in range(n_attrs):
|
for j in range(n_attrs):
|
||||||
if attr_ids[j] != TAG:
|
if attr_ids[j] != TAG:
|
||||||
Token.set_struct_attr(token, attr_ids[j], array[i, j])
|
value = values[j * stride + i]
|
||||||
|
Token.set_struct_attr(token, attr_ids[j], value)
|
||||||
# Set flags
|
# Set flags
|
||||||
self.is_parsed = bool(self.is_parsed or HEAD in attrs)
|
self.is_parsed = bool(self.is_parsed or HEAD in attrs)
|
||||||
self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs)
|
self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user