Group-by seems to be working

This commit is contained in:
Matthew Honnibal 2014-07-07 20:27:02 +02:00
parent 80b36f9f27
commit 571808a274
4 changed files with 66 additions and 39 deletions

View File

@ -39,24 +39,18 @@ cdef struct Lexeme:
cdef Lexeme BLANK_WORD = Lexeme(0, 0, NULL, NULL, NULL)
cdef enum StringAttr:
SIC
LEX
NORM
SHAPE
LAST3
cpdef StringHash attr_of(size_t lex_id, StringAttr attr) except 0
cpdef StringHash sic_of(size_t lex_id) except 0
cpdef StringHash lex_of(size_t lex_id) except 0
cpdef StringHash norm_of(size_t lex_id) except 0
cpdef StringHash shape_of(size_t lex_id) except 0
#cdef Lexeme* init_lexeme(Language lang, unicode string, StringHash hashed,
# int split, size_t length)
# Use these to access the Lexeme fields via get_attr(Lexeme*, LexAttr), which
# has a conditional to pick out the correct item. This allows safe iteration
# over the Lexeme, via:
# for field in range(LexAttr.n): get_attr(Lexeme*, field)
cdef enum HashFields:
sic
lex
normed
cluster
n
#cdef uint64_t get_attr(Lexeme* word, HashFields attr)
cpdef StringHash last3_of(size_t lex_id) except 0

View File

@ -13,6 +13,28 @@ from libcpp.vector cimport vector
from spacy.spacy cimport StringHash
# Reiterate the enum, for python
#SIC = StringAttr.sic
#LEX = StringAttr.lex
#NORM = StringAttr.norm
#SHAPE = StringAttr.shape
#LAST3 = StringAttr.last3
cpdef StringHash attr_of(size_t lex_id, StringAttr attr) except 0:
if attr == SIC:
return sic_of(lex_id)
elif attr == LEX:
return lex_of(lex_id)
elif attr == NORM:
return norm_of(lex_id)
elif attr == SHAPE:
return shape_of(lex_id)
elif attr == LAST3:
return last3_of(lex_id)
else:
raise StandardError
cpdef StringHash sic_of(size_t lex_id) except 0:
'''Access the `sic' field of the Lexeme pointed to by lex_id.
@ -58,6 +80,17 @@ cpdef StringHash shape_of(size_t lex_id) except 0:
return (<Lexeme*>lex_id).orth.shape
cpdef StringHash last3_of(size_t lex_id) except 0:
'''Access the `last3' field of the Lexeme pointed to by lex_id, which stores
the hash of the last three characters of the word:
>>> lex_ids = [lookup(w) for w in (u'Hello', u'!')]
>>> [unhash(last3_of(lex_id)) for lex_id in lex_ids]
[u'llo', u'!']
'''
return (<Lexeme*>lex_id).orth.last3
cpdef ClusterID cluster_of(size_t lex_id):
'''Access the `cluster' field of the Lexeme pointed to by lex_id, which
gives an integer representation of the cluster ID of the word,
@ -101,17 +134,6 @@ cpdef double prob_of(size_t lex_id):
return (<Lexeme*>lex_id).dist.prob
cpdef StringHash last3_of(size_t lex_id):
'''Access the `last3' field of the Lexeme pointed to by lex_id, which stores
the hash of the last three characters of the word:
>>> lex_ids = [lookup(w) for w in (u'Hello', u'!')]
>>> [unhash(last3_of(lex_id)) for lex_id in lex_ids]
[u'llo', u'!']
'''
return (<Lexeme*>lex_id).orth.last3
cpdef bint is_oft_upper(size_t lex_id):
'''Access the `oft_upper' field of the Lexeme pointed to by lex_id, which
stores whether the lowered version of the string hashed by `lex' is found

View File

@ -3,9 +3,7 @@ from spacy.spacy cimport Lexeme_addr
from cython.operator cimport dereference as deref
from spacy.spacy cimport Language
cdef enum Field:
lex
from spacy.lexeme cimport StringAttr
cdef class Tokens:
@ -16,5 +14,5 @@ cdef class Tokens:
cpdef int append(self, Lexeme_addr token)
cpdef int extend(self, Tokens other) except -1
cpdef list group_by(self, Field attr)
cpdef dict count_by(self, Field attr)
cpdef list group_by(self, StringAttr attr)
cpdef dict count_by(self, StringAttr attr)

View File

@ -3,7 +3,7 @@ from cython.operator cimport preincrement as inc
from spacy.lexeme cimport Lexeme
from spacy.lexeme cimport norm_of, shape_of
from spacy.lexeme cimport attr_of, norm_of, shape_of
from spacy.spacy cimport StringHash
@ -37,15 +37,28 @@ cdef class Tokens:
for el in other:
self.append(el)
cpdef list group_by(self, Field attr):
pass
cpdef list group_by(self, StringAttr attr):
cdef dict indices = {}
cdef vector[vector[Lexeme_addr]] groups = vector[vector[Lexeme_addr]]()
cpdef dict count_by(self, Field attr):
cdef StringHash key
cdef Lexeme_addr t
for t in self.vctr[0]:
key = attr_of(t, attr)
if key in indices:
groups[indices[key]].push_back(t)
else:
indices[key] = groups.size()
groups.push_back(vector[Lexeme_addr]())
groups.back().push_back(t)
return groups
cpdef dict count_by(self, StringAttr attr):
counts = {}
cdef Lexeme_addr t
cdef StringHash key
for t in self.vctr[0]:
key = (<Lexeme*>t).lex
key = attr_of(t, attr)
if key not in counts:
counts[key] = 0
counts[key] += 1