mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	Make doc.from_array several times faster
This commit is contained in:
		
							parent
							
								
									de32515bf8
								
							
						
					
					
						commit
						e2279eab1c
					
				|  | @ -806,12 +806,14 @@ cdef class Doc: | |||
| 
 | ||||
|         if SENT_START in attrs and HEAD in attrs: | ||||
|             raise ValueError(Errors.E032) | ||||
|         cdef int i, col, abs_head_index | ||||
|         cdef int i, col | ||||
|         cdef int32_t abs_head_index | ||||
|         cdef attr_id_t attr_id | ||||
|         cdef TokenC* tokens = self.c | ||||
|         cdef int length = len(array) | ||||
|         if length != len(self): | ||||
|             raise ValueError("Cannot set array values longer than the document.") | ||||
| 
 | ||||
|         # Get set up for fast loading | ||||
|         cdef Pool mem = Pool() | ||||
|         cdef int n_attrs = len(attrs) | ||||
|  | @ -822,33 +824,52 @@ cdef class Doc: | |||
|             attr_ids[i] = attr_id | ||||
|         if len(array.shape) == 1: | ||||
|             array = array.reshape((array.size, 1)) | ||||
|         cdef np.ndarray transposed_array = numpy.ascontiguousarray(array.T) | ||||
|         values = <const uint64_t*>transposed_array.data | ||||
|         stride = transposed_array.shape[1] | ||||
|         # Check that all heads are within the document bounds | ||||
|         if HEAD in attrs: | ||||
|             col = attrs.index(HEAD) | ||||
|             for i in range(length): | ||||
|                 # cast index to signed int | ||||
|                 abs_head_index = numpy.int32(array[i, col]) + i | ||||
|                 abs_head_index = <int32_t>values[col * stride + i] | ||||
|                 abs_head_index += i | ||||
|                 if abs_head_index < 0 or abs_head_index >= length: | ||||
|                     raise ValueError(Errors.E190.format(index=i, value=array[i, col], rel_head_index=numpy.int32(array[i, col]))) | ||||
|                     raise ValueError( | ||||
|                         Errors.E190.format( | ||||
|                             index=i, | ||||
|                             value=array[i, col], | ||||
|                             rel_head_index=abs_head_index-i | ||||
|                         ) | ||||
|                     ) | ||||
|         # Do TAG first. This lets subsequent loop override stuff like POS, LEMMA | ||||
|         if TAG in attrs: | ||||
|             col = attrs.index(TAG) | ||||
|             for i in range(length): | ||||
|                 if array[i, col] != 0: | ||||
|                     self.vocab.morphology.assign_tag(&tokens[i], array[i, col]) | ||||
|                 value = values[col * stride + i] | ||||
|                 if value != 0: | ||||
|                     self.vocab.morphology.assign_tag(&tokens[i], value) | ||||
|         # Verify ENT_IOB are proper integers | ||||
|         if ENT_IOB in attrs: | ||||
|             iob_strings = Token.iob_strings() | ||||
|             col = attrs.index(ENT_IOB) | ||||
|             n_iob_strings = len(iob_strings) | ||||
|             for i in range(length): | ||||
|                 if array[i, col] not in range(0, len(iob_strings)): | ||||
|                     raise ValueError(Errors.E982.format(values=iob_strings, value=array[i, col])) | ||||
|                 value = values[col * stride + i] | ||||
|                 if value < 0 or value >= n_iob_strings: | ||||
|                     raise ValueError( | ||||
|                         Errors.E982.format( | ||||
|                             values=iob_strings, | ||||
|                             value=value | ||||
|                         ) | ||||
|                     ) | ||||
|         # Now load the data | ||||
|         for i in range(length): | ||||
|             token = &self.c[i] | ||||
|             for j in range(n_attrs): | ||||
|                 if attr_ids[j] != TAG: | ||||
|                     Token.set_struct_attr(token, attr_ids[j], array[i, j]) | ||||
|                     value = values[j * stride + i] | ||||
|                     Token.set_struct_attr(token, attr_ids[j], value) | ||||
|         # Set flags | ||||
|         self.is_parsed = bool(self.is_parsed or HEAD in attrs) | ||||
|         self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user