Fix doc.count_by functionality (#3950)

Fix doc.count_by functionality
This commit is contained in:
Ines Montani 2019-07-11 13:44:00 +02:00 committed by GitHub
commit 673c864a06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 36 additions and 7 deletions

View File

@ -5,7 +5,6 @@ import logging
from pathlib import Path
from collections import defaultdict
from gensim.models import Word2Vec
from preshed.counter import PreshCounter
import plac
import spacy

View File

@ -0,0 +1,31 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy.attrs import IS_ALPHA
from spacy.lang.en import English
@pytest.mark.parametrize(
"sentence",
[
'The story was to the effect that a young American student recently called on Professor Christlieb with a letter of introduction.',
'The next month Barry Siddall joined Stoke City on a free transfer, after Chris Pearce had established himself as the Vale\'s #1.',
'The next month Barry Siddall joined Stoke City on a free transfer, after Chris Pearce had established himself as the Vale\'s number one',
'Indeed, making the one who remains do all the work has installed him into a position of such insolent tyranny, it will take a month at least to reduce him to his proper proportions.',
"It was a missed assignment, but it shouldn't have resulted in a turnover ..."
],
)
def test_issue3869(sentence):
"""Test that the Doc's count_by function works consistently"""
nlp = English()
doc = nlp(sentence)
count = 0
for token in doc:
count += token.is_alpha
assert count == doc.count_by(IS_ALPHA).get(1, 0)

View File

@ -1,6 +1,5 @@
from cymem.cymem cimport Pool
cimport numpy as np
from preshed.counter cimport PreshCounter
from ..vocab cimport Vocab
from ..structs cimport TokenC, LexemeC

View File

@ -9,6 +9,7 @@ cimport cython
cimport numpy as np
from libc.string cimport memcpy, memset
from libc.math cimport sqrt
from collections import Counter
import numpy
import numpy.linalg
@ -697,7 +698,7 @@ cdef class Doc:
# Handle 1d case
return output if len(attr_ids) >= 2 else output.reshape((self.length,))
def count_by(self, attr_id_t attr_id, exclude=None, PreshCounter counts=None):
def count_by(self, attr_id_t attr_id, exclude=None, object counts=None):
"""Count the frequencies of a given attribute. Produces a dict of
`{attribute (int): count (ints)}` frequencies, keyed by the values of
the given attribute ID.
@ -712,19 +713,18 @@ cdef class Doc:
cdef size_t count
if counts is None:
counts = PreshCounter()
counts = Counter()
output_dict = True
else:
output_dict = False
# Take this check out of the loop, for a bit of extra speed
if exclude is None:
for i in range(self.length):
counts.inc(get_token_attr(&self.c[i], attr_id), 1)
counts[get_token_attr(&self.c[i], attr_id)] += 1
else:
for i in range(self.length):
if not exclude(self[i]):
attr = get_token_attr(&self.c[i], attr_id)
counts.inc(attr, 1)
counts[get_token_attr(&self.c[i], attr_id)] += 1
if output_dict:
return dict(counts)