Clean up `to_array` implementation using fixes on master

This commit is contained in:
Ramanan Balakrishnan 2017-10-20 17:09:37 +05:30
parent d44a079fe3
commit 0726946563
No known key found for this signature in database
GPG Key ID: 57283041B6B6D1D1

View File

@ -557,28 +557,25 @@ cdef class Doc:
""" """
cdef int i, j cdef int i, j
cdef attr_id_t feature cdef attr_id_t feature
cdef np.ndarray[attr_t, ndim=1] attr_ids
cdef np.ndarray[attr_t, ndim=2] output cdef np.ndarray[attr_t, ndim=2] output
cdef np.ndarray[attr_t, ndim=1] output_1D
# Handle scalar/list inputs of strings/ints for py_attr_ids # Handle scalar/list inputs of strings/ints for py_attr_ids
if( type(py_attr_ids) is not list and type(py_attr_ids) is not tuple ): if not hasattr(py_attr_ids, '__iter__'):
py_attr_ids = [ py_attr_ids ] py_attr_ids = [py_attr_ids]
py_attr_ids_input = []
for py_attr_id in py_attr_ids: # Allow strings, e.g. 'lemma' or 'LEMMA'
if( type(py_attr_id) is int ): py_attr_ids = [(IDS[id_.upper()] if hasattr(id_, 'upper') else id_)
py_attr_ids_input.append(py_attr_id) for id_ in py_attr_ids]
else:
py_attr_ids_input.append(IDS[py_attr_id.upper()])
# Make an array from the attributes --- otherwise our inner loop is Python # Make an array from the attributes --- otherwise our inner loop is Python
# dict iteration. # dict iteration.
cdef np.ndarray[attr_t, ndim=1] attr_ids = numpy.asarray(py_attr_ids_input, dtype=numpy.uint64) attr_ids = numpy.asarray(py_attr_ids, dtype=numpy.uint64)
output = numpy.ndarray(shape=(self.length, len(attr_ids)), dtype=numpy.uint64) output = numpy.ndarray(shape=(self.length, len(attr_ids)), dtype=numpy.uint64)
for i in range(self.length): for i in range(self.length):
for j, feature in enumerate(attr_ids): for j, feature in enumerate(attr_ids):
output[i, j] = get_token_attr(&self.c[i], feature) output[i, j] = get_token_attr(&self.c[i], feature)
if( len(attr_ids) == 1 ): # Handle 1d case
output_1D = output.reshape((self.length)) return output if len(attr_ids) >= 2 else output.reshape((self.length,))
return output_1D
return output
def count_by(self, attr_id_t attr_id, exclude=None, PreshCounter counts=None): def count_by(self, attr_id_t attr_id, exclude=None, PreshCounter counts=None):
"""Count the frequencies of a given attribute. Produces a dict of """Count the frequencies of a given attribute. Produces a dict of