From 6d821b2e5559151f28880da0ff4a90e391e87657 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sat, 20 Jun 2020 20:17:13 +0200 Subject: [PATCH] Make doc.from_array several times faster --- spacy/tokens/doc.pyx | 37 +++++++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index 686f3be54..72a16b854 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -806,12 +806,14 @@ cdef class Doc: if SENT_START in attrs and HEAD in attrs: raise ValueError(Errors.E032) - cdef int i, col, abs_head_index + cdef int i, col + cdef int32_t abs_head_index cdef attr_id_t attr_id cdef TokenC* tokens = self.c cdef int length = len(array) if length != len(self): raise ValueError("Cannot set array values longer than the document.") + # Get set up for fast loading cdef Pool mem = Pool() cdef int n_attrs = len(attrs) @@ -822,33 +824,52 @@ cdef class Doc: attr_ids[i] = attr_id if len(array.shape) == 1: array = array.reshape((array.size, 1)) + cdef np.ndarray transposed_array = numpy.ascontiguousarray(array.T) + values = transposed_array.data + stride = transposed_array.shape[1] # Check that all heads are within the document bounds if HEAD in attrs: col = attrs.index(HEAD) for i in range(length): # cast index to signed int - abs_head_index = numpy.int32(array[i, col]) + i + abs_head_index = values[col * stride + i] + abs_head_index += i if abs_head_index < 0 or abs_head_index >= length: - raise ValueError(Errors.E190.format(index=i, value=array[i, col], rel_head_index=numpy.int32(array[i, col]))) + raise ValueError( + Errors.E190.format( + index=i, + value=array[i, col], + rel_head_index=abs_head_index-i + ) + ) # Do TAG first. This lets subsequent loop override stuff like POS, LEMMA if TAG in attrs: col = attrs.index(TAG) for i in range(length): - if array[i, col] != 0: - self.vocab.morphology.assign_tag(&tokens[i], array[i, col]) + value = values[col * stride + i] + if value != 0: + self.vocab.morphology.assign_tag(&tokens[i], value) # Verify ENT_IOB are proper integers if ENT_IOB in attrs: iob_strings = Token.iob_strings() col = attrs.index(ENT_IOB) + n_iob_strings = len(iob_strings) for i in range(length): - if array[i, col] not in range(0, len(iob_strings)): - raise ValueError(Errors.E982.format(values=iob_strings, value=array[i, col])) + value = values[col * stride + i] + if value < 0 or value >= n_iob_strings: + raise ValueError( + Errors.E982.format( + values=iob_strings, + value=value + ) + ) # Now load the data for i in range(length): token = &self.c[i] for j in range(n_attrs): if attr_ids[j] != TAG: - Token.set_struct_attr(token, attr_ids[j], array[i, j]) + value = values[j * stride + i] + Token.set_struct_attr(token, attr_ids[j], value) # Set flags self.is_parsed = bool(self.is_parsed or HEAD in attrs) self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs)