mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	Make doc.from_array several times faster
This commit is contained in:
		
							parent
							
								
									de32515bf8
								
							
						
					
					
						commit
						e2279eab1c
					
				|  | @ -806,12 +806,14 @@ cdef class Doc: | ||||||
| 
 | 
 | ||||||
|         if SENT_START in attrs and HEAD in attrs: |         if SENT_START in attrs and HEAD in attrs: | ||||||
|             raise ValueError(Errors.E032) |             raise ValueError(Errors.E032) | ||||||
|         cdef int i, col, abs_head_index |         cdef int i, col | ||||||
|  |         cdef int32_t abs_head_index | ||||||
|         cdef attr_id_t attr_id |         cdef attr_id_t attr_id | ||||||
|         cdef TokenC* tokens = self.c |         cdef TokenC* tokens = self.c | ||||||
|         cdef int length = len(array) |         cdef int length = len(array) | ||||||
|         if length != len(self): |         if length != len(self): | ||||||
|             raise ValueError("Cannot set array values longer than the document.") |             raise ValueError("Cannot set array values longer than the document.") | ||||||
|  | 
 | ||||||
|         # Get set up for fast loading |         # Get set up for fast loading | ||||||
|         cdef Pool mem = Pool() |         cdef Pool mem = Pool() | ||||||
|         cdef int n_attrs = len(attrs) |         cdef int n_attrs = len(attrs) | ||||||
|  | @ -822,33 +824,52 @@ cdef class Doc: | ||||||
|             attr_ids[i] = attr_id |             attr_ids[i] = attr_id | ||||||
|         if len(array.shape) == 1: |         if len(array.shape) == 1: | ||||||
|             array = array.reshape((array.size, 1)) |             array = array.reshape((array.size, 1)) | ||||||
|  |         cdef np.ndarray transposed_array = numpy.ascontiguousarray(array.T) | ||||||
|  |         values = <const uint64_t*>transposed_array.data | ||||||
|  |         stride = transposed_array.shape[1] | ||||||
|         # Check that all heads are within the document bounds |         # Check that all heads are within the document bounds | ||||||
|         if HEAD in attrs: |         if HEAD in attrs: | ||||||
|             col = attrs.index(HEAD) |             col = attrs.index(HEAD) | ||||||
|             for i in range(length): |             for i in range(length): | ||||||
|                 # cast index to signed int |                 # cast index to signed int | ||||||
|                 abs_head_index = numpy.int32(array[i, col]) + i |                 abs_head_index = <int32_t>values[col * stride + i] | ||||||
|  |                 abs_head_index += i | ||||||
|                 if abs_head_index < 0 or abs_head_index >= length: |                 if abs_head_index < 0 or abs_head_index >= length: | ||||||
|                     raise ValueError(Errors.E190.format(index=i, value=array[i, col], rel_head_index=numpy.int32(array[i, col]))) |                     raise ValueError( | ||||||
|  |                         Errors.E190.format( | ||||||
|  |                             index=i, | ||||||
|  |                             value=array[i, col], | ||||||
|  |                             rel_head_index=abs_head_index-i | ||||||
|  |                         ) | ||||||
|  |                     ) | ||||||
|         # Do TAG first. This lets subsequent loop override stuff like POS, LEMMA |         # Do TAG first. This lets subsequent loop override stuff like POS, LEMMA | ||||||
|         if TAG in attrs: |         if TAG in attrs: | ||||||
|             col = attrs.index(TAG) |             col = attrs.index(TAG) | ||||||
|             for i in range(length): |             for i in range(length): | ||||||
|                 if array[i, col] != 0: |                 value = values[col * stride + i] | ||||||
|                     self.vocab.morphology.assign_tag(&tokens[i], array[i, col]) |                 if value != 0: | ||||||
|  |                     self.vocab.morphology.assign_tag(&tokens[i], value) | ||||||
|         # Verify ENT_IOB are proper integers |         # Verify ENT_IOB are proper integers | ||||||
|         if ENT_IOB in attrs: |         if ENT_IOB in attrs: | ||||||
|             iob_strings = Token.iob_strings() |             iob_strings = Token.iob_strings() | ||||||
|             col = attrs.index(ENT_IOB) |             col = attrs.index(ENT_IOB) | ||||||
|  |             n_iob_strings = len(iob_strings) | ||||||
|             for i in range(length): |             for i in range(length): | ||||||
|                 if array[i, col] not in range(0, len(iob_strings)): |                 value = values[col * stride + i] | ||||||
|                     raise ValueError(Errors.E982.format(values=iob_strings, value=array[i, col])) |                 if value < 0 or value >= n_iob_strings: | ||||||
|  |                     raise ValueError( | ||||||
|  |                         Errors.E982.format( | ||||||
|  |                             values=iob_strings, | ||||||
|  |                             value=value | ||||||
|  |                         ) | ||||||
|  |                     ) | ||||||
|         # Now load the data |         # Now load the data | ||||||
|         for i in range(length): |         for i in range(length): | ||||||
|             token = &self.c[i] |             token = &self.c[i] | ||||||
|             for j in range(n_attrs): |             for j in range(n_attrs): | ||||||
|                 if attr_ids[j] != TAG: |                 if attr_ids[j] != TAG: | ||||||
|                     Token.set_struct_attr(token, attr_ids[j], array[i, j]) |                     value = values[j * stride + i] | ||||||
|  |                     Token.set_struct_attr(token, attr_ids[j], value) | ||||||
|         # Set flags |         # Set flags | ||||||
|         self.is_parsed = bool(self.is_parsed or HEAD in attrs) |         self.is_parsed = bool(self.is_parsed or HEAD in attrs) | ||||||
|         self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs) |         self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user