diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index 6e7230428..9351ba366 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -557,28 +557,25 @@ cdef class Doc: """ cdef int i, j cdef attr_id_t feature + cdef np.ndarray[attr_t, ndim=1] attr_ids cdef np.ndarray[attr_t, ndim=2] output - cdef np.ndarray[attr_t, ndim=1] output_1D # Handle scalar/list inputs of strings/ints for py_attr_ids - if( type(py_attr_ids) is not list and type(py_attr_ids) is not tuple ): - py_attr_ids = [ py_attr_ids ] - py_attr_ids_input = [] - for py_attr_id in py_attr_ids: - if( type(py_attr_id) is int ): - py_attr_ids_input.append(py_attr_id) - else: - py_attr_ids_input.append(IDS[py_attr_id.upper()]) + if not hasattr(py_attr_ids, '__iter__'): + py_attr_ids = [py_attr_ids] + + # Allow strings, e.g. 'lemma' or 'LEMMA' + py_attr_ids = [(IDS[id_.upper()] if hasattr(id_, 'upper') else id_) + for id_ in py_attr_ids] # Make an array from the attributes --- otherwise our inner loop is Python # dict iteration. - cdef np.ndarray[attr_t, ndim=1] attr_ids = numpy.asarray(py_attr_ids_input, dtype=numpy.uint64) + attr_ids = numpy.asarray(py_attr_ids, dtype=numpy.uint64) output = numpy.ndarray(shape=(self.length, len(attr_ids)), dtype=numpy.uint64) for i in range(self.length): for j, feature in enumerate(attr_ids): output[i, j] = get_token_attr(&self.c[i], feature) - if( len(attr_ids) == 1 ): - output_1D = output.reshape((self.length)) - return output_1D - return output + # Handle 1d case + return output if len(attr_ids) >= 2 else output.reshape((self.length,)) + def count_by(self, attr_id_t attr_id, exclude=None, PreshCounter counts=None): """Count the frequencies of a given attribute. Produces a dict of