* Extend count_by method

This commit is contained in:
Matthew Honnibal 2015-07-14 03:20:09 +02:00
parent 39c93116eb
commit 935ac53ee3

View File

@@ -218,7 +218,7 @@ cdef class Doc:
output[i, j] = get_token_attr(&self.data[i], feature) output[i, j] = get_token_attr(&self.data[i], feature)
return output return output
def count_by(self, attr_id_t attr_id, exclude=None): def count_by(self, attr_id_t attr_id, exclude=None, PreshCounter counts=None):
"""Produce a dict of {attribute (int): count (ints)} frequencies, keyed """Produce a dict of {attribute (int): count (ints)} frequencies, keyed
by the values of the given attribute ID. by the values of the given attribute ID.
@@ -237,12 +237,22 @@ cdef class Doc:
cdef attr_t attr cdef attr_t attr
cdef size_t count cdef size_t count
cdef PreshCounter counts = PreshCounter(2 ** 8) if counts is None:
counts = PreshCounter(self.length)
output_dict = True
else:
output_dict = False
# Take this check out of the loop, for a bit of extra speed
if exclude is None:
for i in range(self.length): for i in range(self.length):
if exclude is not None and exclude(self[i]):
continue
attr = get_token_attr(&self.data[i], attr_id) attr = get_token_attr(&self.data[i], attr_id)
counts.inc(attr, 1) counts.inc(attr, 1)
else:
for i in range(self.length):
if not exclude(self[i]):
attr = get_token_attr(&self.data[i], attr_id)
counts.inc(attr, 1)
if output_dict:
return dict(counts) return dict(counts)
def _realloc(self, new_size): def _realloc(self, new_size):