mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
* Extend count_by method
This commit is contained in:
parent
39c93116eb
commit
935ac53ee3
|
@ -218,7 +218,7 @@ cdef class Doc:
|
|||
output[i, j] = get_token_attr(&self.data[i], feature)
|
||||
return output
|
||||
|
||||
def count_by(self, attr_id_t attr_id, exclude=None):
|
||||
def count_by(self, attr_id_t attr_id, exclude=None, PreshCounter counts=None):
|
||||
"""Produce a dict of {attribute (int): count (ints)} frequencies, keyed
|
||||
by the values of the given attribute ID.
|
||||
|
||||
|
@ -237,12 +237,22 @@ cdef class Doc:
|
|||
cdef attr_t attr
|
||||
cdef size_t count
|
||||
|
||||
cdef PreshCounter counts = PreshCounter(2 ** 8)
|
||||
if counts is None:
|
||||
counts = PreshCounter(self.length)
|
||||
output_dict = True
|
||||
else:
|
||||
output_dict = False
|
||||
# Take this check out of the loop, for a bit of extra speed
|
||||
if exclude is None:
|
||||
for i in range(self.length):
|
||||
if exclude is not None and exclude(self[i]):
|
||||
continue
|
||||
attr = get_token_attr(&self.data[i], attr_id)
|
||||
counts.inc(attr, 1)
|
||||
else:
|
||||
for i in range(self.length):
|
||||
if not exclude(self[i]):
|
||||
attr = get_token_attr(&self.data[i], attr_id)
|
||||
counts.inc(attr, 1)
|
||||
if output_dict:
|
||||
return dict(counts)
|
||||
|
||||
def _realloc(self, new_size):
|
||||
|
|
Loading…
Reference in New Issue
Block a user