mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
Group-by seems to be working
This commit is contained in:
parent
80b36f9f27
commit
571808a274
|
@ -39,24 +39,18 @@ cdef struct Lexeme:
|
|||
cdef Lexeme BLANK_WORD = Lexeme(0, 0, NULL, NULL, NULL)
|
||||
|
||||
|
||||
cdef enum StringAttr:
|
||||
SIC
|
||||
LEX
|
||||
NORM
|
||||
SHAPE
|
||||
LAST3
|
||||
|
||||
|
||||
cpdef StringHash attr_of(size_t lex_id, StringAttr attr) except 0
|
||||
|
||||
cpdef StringHash sic_of(size_t lex_id) except 0
|
||||
cpdef StringHash lex_of(size_t lex_id) except 0
|
||||
cpdef StringHash norm_of(size_t lex_id) except 0
|
||||
cpdef StringHash shape_of(size_t lex_id) except 0
|
||||
#cdef Lexeme* init_lexeme(Language lang, unicode string, StringHash hashed,
|
||||
# int split, size_t length)
|
||||
|
||||
|
||||
|
||||
# Use these to access the Lexeme fields via get_attr(Lexeme*, LexAttr), which
|
||||
# has a conditional to pick out the correct item. This allows safe iteration
|
||||
# over the Lexeme, via:
|
||||
# for field in range(LexAttr.n): get_attr(Lexeme*, field)
|
||||
cdef enum HashFields:
|
||||
sic
|
||||
lex
|
||||
normed
|
||||
cluster
|
||||
n
|
||||
|
||||
|
||||
#cdef uint64_t get_attr(Lexeme* word, HashFields attr)
|
||||
cpdef StringHash last3_of(size_t lex_id) except 0
|
||||
|
|
|
@ -13,6 +13,28 @@ from libcpp.vector cimport vector
|
|||
|
||||
from spacy.spacy cimport StringHash
|
||||
|
||||
# Reiterate the enum, for python
|
||||
#SIC = StringAttr.sic
|
||||
#LEX = StringAttr.lex
|
||||
#NORM = StringAttr.norm
|
||||
#SHAPE = StringAttr.shape
|
||||
#LAST3 = StringAttr.last3
|
||||
|
||||
|
||||
cpdef StringHash attr_of(size_t lex_id, StringAttr attr) except 0:
|
||||
if attr == SIC:
|
||||
return sic_of(lex_id)
|
||||
elif attr == LEX:
|
||||
return lex_of(lex_id)
|
||||
elif attr == NORM:
|
||||
return norm_of(lex_id)
|
||||
elif attr == SHAPE:
|
||||
return shape_of(lex_id)
|
||||
elif attr == LAST3:
|
||||
return last3_of(lex_id)
|
||||
else:
|
||||
raise StandardError
|
||||
|
||||
|
||||
cpdef StringHash sic_of(size_t lex_id) except 0:
|
||||
'''Access the `sic' field of the Lexeme pointed to by lex_id.
|
||||
|
@ -58,6 +80,17 @@ cpdef StringHash shape_of(size_t lex_id) except 0:
|
|||
return (<Lexeme*>lex_id).orth.shape
|
||||
|
||||
|
||||
cpdef StringHash last3_of(size_t lex_id) except 0:
|
||||
'''Access the `last3' field of the Lexeme pointed to by lex_id, which stores
|
||||
the hash of the last three characters of the word:
|
||||
|
||||
>>> lex_ids = [lookup(w) for w in (u'Hello', u'!')]
|
||||
>>> [unhash(last3_of(lex_id)) for lex_id in lex_ids]
|
||||
[u'llo', u'!']
|
||||
'''
|
||||
return (<Lexeme*>lex_id).orth.last3
|
||||
|
||||
|
||||
cpdef ClusterID cluster_of(size_t lex_id):
|
||||
'''Access the `cluster' field of the Lexeme pointed to by lex_id, which
|
||||
gives an integer representation of the cluster ID of the word,
|
||||
|
@ -101,17 +134,6 @@ cpdef double prob_of(size_t lex_id):
|
|||
return (<Lexeme*>lex_id).dist.prob
|
||||
|
||||
|
||||
cpdef StringHash last3_of(size_t lex_id):
|
||||
'''Access the `last3' field of the Lexeme pointed to by lex_id, which stores
|
||||
the hash of the last three characters of the word:
|
||||
|
||||
>>> lex_ids = [lookup(w) for w in (u'Hello', u'!')]
|
||||
>>> [unhash(last3_of(lex_id)) for lex_id in lex_ids]
|
||||
[u'llo', u'!']
|
||||
'''
|
||||
return (<Lexeme*>lex_id).orth.last3
|
||||
|
||||
|
||||
cpdef bint is_oft_upper(size_t lex_id):
|
||||
'''Access the `oft_upper' field of the Lexeme pointed to by lex_id, which
|
||||
stores whether the lowered version of the string hashed by `lex' is found
|
||||
|
|
|
@ -3,9 +3,7 @@ from spacy.spacy cimport Lexeme_addr
|
|||
|
||||
from cython.operator cimport dereference as deref
|
||||
from spacy.spacy cimport Language
|
||||
|
||||
cdef enum Field:
|
||||
lex
|
||||
from spacy.lexeme cimport StringAttr
|
||||
|
||||
|
||||
cdef class Tokens:
|
||||
|
@ -16,5 +14,5 @@ cdef class Tokens:
|
|||
cpdef int append(self, Lexeme_addr token)
|
||||
cpdef int extend(self, Tokens other) except -1
|
||||
|
||||
cpdef list group_by(self, Field attr)
|
||||
cpdef dict count_by(self, Field attr)
|
||||
cpdef list group_by(self, StringAttr attr)
|
||||
cpdef dict count_by(self, StringAttr attr)
|
||||
|
|
|
@ -3,7 +3,7 @@ from cython.operator cimport preincrement as inc
|
|||
|
||||
|
||||
from spacy.lexeme cimport Lexeme
|
||||
from spacy.lexeme cimport norm_of, shape_of
|
||||
from spacy.lexeme cimport attr_of, norm_of, shape_of
|
||||
from spacy.spacy cimport StringHash
|
||||
|
||||
|
||||
|
@ -37,15 +37,28 @@ cdef class Tokens:
|
|||
for el in other:
|
||||
self.append(el)
|
||||
|
||||
cpdef list group_by(self, Field attr):
|
||||
pass
|
||||
cpdef list group_by(self, StringAttr attr):
|
||||
cdef dict indices = {}
|
||||
cdef vector[vector[Lexeme_addr]] groups = vector[vector[Lexeme_addr]]()
|
||||
|
||||
cpdef dict count_by(self, Field attr):
|
||||
cdef StringHash key
|
||||
cdef Lexeme_addr t
|
||||
for t in self.vctr[0]:
|
||||
key = attr_of(t, attr)
|
||||
if key in indices:
|
||||
groups[indices[key]].push_back(t)
|
||||
else:
|
||||
indices[key] = groups.size()
|
||||
groups.push_back(vector[Lexeme_addr]())
|
||||
groups.back().push_back(t)
|
||||
return groups
|
||||
|
||||
cpdef dict count_by(self, StringAttr attr):
|
||||
counts = {}
|
||||
cdef Lexeme_addr t
|
||||
cdef StringHash key
|
||||
for t in self.vctr[0]:
|
||||
key = (<Lexeme*>t).lex
|
||||
key = attr_of(t, attr)
|
||||
if key not in counts:
|
||||
counts[key] = 0
|
||||
counts[key] += 1
|
||||
|
|
Loading…
Reference in New Issue
Block a user