mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Improve efficiency of Doc.to_array
This commit is contained in:
parent
2acc907d55
commit
e10e9ad2c5
|
@ -1,6 +1,7 @@
|
|||
# coding: utf8
|
||||
# cython: infer_types=True
|
||||
# cython: bounds_check=False
|
||||
# cython: profile=True
|
||||
from __future__ import unicode_literals
|
||||
|
||||
cimport cython
|
||||
|
@ -567,7 +568,6 @@ cdef class Doc:
|
|||
"""
|
||||
cdef int i, j
|
||||
cdef attr_id_t feature
|
||||
cdef np.ndarray[attr_t, ndim=1] attr_ids
|
||||
cdef np.ndarray[attr_t, ndim=2] output
|
||||
# Handle scalar/list inputs of strings/ints for py_attr_ids
|
||||
if not hasattr(py_attr_ids, '__iter__') \
|
||||
|
@ -579,12 +579,17 @@ cdef class Doc:
|
|||
for id_ in py_attr_ids]
|
||||
# Make an array from the attributes --- otherwise our inner loop is
|
||||
# Python dict iteration.
|
||||
attr_ids = numpy.asarray(py_attr_ids, dtype=numpy.uint64)
|
||||
cdef np.ndarray attr_ids = numpy.asarray(py_attr_ids, dtype='i')
|
||||
output = numpy.ndarray(shape=(self.length, len(attr_ids)),
|
||||
dtype=numpy.uint64)
|
||||
c_output = <attr_t*>output.data
|
||||
c_attr_ids = <attr_id_t*>attr_ids.data
|
||||
cdef TokenC* token
|
||||
cdef int nr_attr = attr_ids.shape[0]
|
||||
for i in range(self.length):
|
||||
for j, feature in enumerate(attr_ids):
|
||||
output[i, j] = get_token_attr(&self.c[i], feature)
|
||||
token = &self.c[i]
|
||||
for j in range(nr_attr):
|
||||
c_output[i*nr_attr + j] = get_token_attr(token, c_attr_ids[j])
|
||||
# Handle 1d case
|
||||
return output if len(attr_ids) >= 2 else output.reshape((self.length,))
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user