mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 05:31:15 +03:00 
			
		
		
		
	Improve efficiency of Doc.to_array
This commit is contained in:
		
							parent
							
								
									2acc907d55
								
							
						
					
					
						commit
						e10e9ad2c5
					
				|  | @ -1,6 +1,7 @@ | |||
| # coding: utf8 | ||||
| # cython: infer_types=True | ||||
| # cython: bounds_check=False | ||||
| # cython: profile=True | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| cimport cython | ||||
|  | @ -567,7 +568,6 @@ cdef class Doc: | |||
|         """ | ||||
|         cdef int i, j | ||||
|         cdef attr_id_t feature | ||||
|         cdef np.ndarray[attr_t, ndim=1] attr_ids | ||||
|         cdef np.ndarray[attr_t, ndim=2] output | ||||
|         # Handle scalar/list inputs of strings/ints for py_attr_ids | ||||
|         if not hasattr(py_attr_ids, '__iter__') \ | ||||
|  | @ -579,12 +579,17 @@ cdef class Doc: | |||
|                        for id_ in py_attr_ids] | ||||
|         # Make an array from the attributes --- otherwise our inner loop is | ||||
|         # Python dict iteration. | ||||
|         attr_ids = numpy.asarray(py_attr_ids, dtype=numpy.uint64) | ||||
|         cdef np.ndarray attr_ids = numpy.asarray(py_attr_ids, dtype='i') | ||||
|         output = numpy.ndarray(shape=(self.length, len(attr_ids)), | ||||
|                                dtype=numpy.uint64) | ||||
|         c_output = <attr_t*>output.data | ||||
|         c_attr_ids = <attr_id_t*>attr_ids.data | ||||
|         cdef TokenC* token | ||||
|         cdef int nr_attr = attr_ids.shape[0] | ||||
|         for i in range(self.length): | ||||
|             for j, feature in enumerate(attr_ids): | ||||
|                 output[i, j] = get_token_attr(&self.c[i], feature) | ||||
|             token = &self.c[i] | ||||
|             for j in range(nr_attr): | ||||
|                 c_output[i*nr_attr + j] = get_token_attr(token, c_attr_ids[j]) | ||||
|         # Handle 1d case | ||||
|         return output if len(attr_ids) >= 2 else output.reshape((self.length,)) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user