Fix doc.count_by functionality (#3950)

Fix doc.count_by functionality
This commit is contained in:
Ines Montani 2019-07-11 13:44:00 +02:00 committed by GitHub
commit 673c864a06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 36 additions and 7 deletions

View File

@ -5,7 +5,6 @@ import logging
from pathlib import Path from pathlib import Path
from collections import defaultdict from collections import defaultdict
from gensim.models import Word2Vec from gensim.models import Word2Vec
from preshed.counter import PreshCounter
import plac import plac
import spacy import spacy

View File

@ -0,0 +1,31 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy.attrs import IS_ALPHA
from spacy.lang.en import English
@pytest.mark.parametrize(
"sentence",
[
'The story was to the effect that a young American student recently called on Professor Christlieb with a letter of introduction.',
'The next month Barry Siddall joined Stoke City on a free transfer, after Chris Pearce had established himself as the Vale\'s #1.',
'The next month Barry Siddall joined Stoke City on a free transfer, after Chris Pearce had established himself as the Vale\'s number one',
'Indeed, making the one who remains do all the work has installed him into a position of such insolent tyranny, it will take a month at least to reduce him to his proper proportions.',
"It was a missed assignment, but it shouldn't have resulted in a turnover ..."
],
)
def test_issue3869(sentence):
"""Test that the Doc's count_by function works consistently"""
nlp = English()
doc = nlp(sentence)
count = 0
for token in doc:
count += token.is_alpha
assert count == doc.count_by(IS_ALPHA).get(1, 0)

View File

@ -1,6 +1,5 @@
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
cimport numpy as np cimport numpy as np
from preshed.counter cimport PreshCounter
from ..vocab cimport Vocab from ..vocab cimport Vocab
from ..structs cimport TokenC, LexemeC from ..structs cimport TokenC, LexemeC

View File

@ -9,6 +9,7 @@ cimport cython
cimport numpy as np cimport numpy as np
from libc.string cimport memcpy, memset from libc.string cimport memcpy, memset
from libc.math cimport sqrt from libc.math cimport sqrt
from collections import Counter
import numpy import numpy
import numpy.linalg import numpy.linalg
@ -697,7 +698,7 @@ cdef class Doc:
# Handle 1d case # Handle 1d case
return output if len(attr_ids) >= 2 else output.reshape((self.length,)) return output if len(attr_ids) >= 2 else output.reshape((self.length,))
def count_by(self, attr_id_t attr_id, exclude=None, PreshCounter counts=None): def count_by(self, attr_id_t attr_id, exclude=None, object counts=None):
"""Count the frequencies of a given attribute. Produces a dict of """Count the frequencies of a given attribute. Produces a dict of
`{attribute (int): count (ints)}` frequencies, keyed by the values of `{attribute (int): count (ints)}` frequencies, keyed by the values of
the given attribute ID. the given attribute ID.
@ -712,19 +713,18 @@ cdef class Doc:
cdef size_t count cdef size_t count
if counts is None: if counts is None:
counts = PreshCounter() counts = Counter()
output_dict = True output_dict = True
else: else:
output_dict = False output_dict = False
# Take this check out of the loop, for a bit of extra speed # Take this check out of the loop, for a bit of extra speed
if exclude is None: if exclude is None:
for i in range(self.length): for i in range(self.length):
counts.inc(get_token_attr(&self.c[i], attr_id), 1) counts[get_token_attr(&self.c[i], attr_id)] += 1
else: else:
for i in range(self.length): for i in range(self.length):
if not exclude(self[i]): if not exclude(self[i]):
attr = get_token_attr(&self.c[i], attr_id) counts[get_token_attr(&self.c[i], attr_id)] += 1
counts.inc(attr, 1)
if output_dict: if output_dict:
return dict(counts) return dict(counts)