mirror of https://github.com/explosion/spaCy.git
Group-by seems to be working
This commit is contained in:
parent 80b36f9f27, commit 571808a274
@@ -39,24 +39,18 @@ cdef struct Lexeme:
 
 cdef Lexeme BLANK_WORD = Lexeme(0, 0, NULL, NULL, NULL)
 
 
+cdef enum StringAttr:
+    SIC
+    LEX
+    NORM
+    SHAPE
+    LAST3
+
+
+cpdef StringHash attr_of(size_t lex_id, StringAttr attr) except 0
+
+cpdef StringHash sic_of(size_t lex_id) except 0
 cpdef StringHash lex_of(size_t lex_id) except 0
 cpdef StringHash norm_of(size_t lex_id) except 0
 cpdef StringHash shape_of(size_t lex_id) except 0
-#cdef Lexeme* init_lexeme(Language lang, unicode string, StringHash hashed,
-#                         int split, size_t length)
-
-
-# Use these to access the Lexeme fields via get_attr(Lexeme*, LexAttr), which
-# has a conditional to pick out the correct item. This allows safe iteration
-# over the Lexeme, via:
-# for field in range(LexAttr.n): get_attr(Lexeme*, field)
-cdef enum HashFields:
-    sic
-    lex
-    normed
-    cluster
-    n
-
-
-#cdef uint64_t get_attr(Lexeme* word, HashFields attr)
-
+cpdef StringHash last3_of(size_t lex_id) except 0
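The hunk above drops the private HashFields enum (and the commented-out get_attr helper) in favour of a public StringAttr selector plus a single attr_of accessor, so a caller can choose at runtime which hashed string view of a lexeme it wants. A minimal pure-Python sketch of that access pattern, assuming a stand-in Lexeme record; the names mirror the Cython declarations but are illustrative, not spaCy's C API:

    from enum import IntEnum
    from collections import namedtuple

    # Stand-in for the C Lexeme struct: each field holds a precomputed string hash.
    Lexeme = namedtuple("Lexeme", "sic lex norm shape last3")

    class StringAttr(IntEnum):
        SIC = 0      # hash of the literal string
        LEX = 1      # hash of the form stored in the lexicon
        NORM = 2     # hash of the normalised form
        SHAPE = 3    # hash of the word shape
        LAST3 = 4    # hash of the last three characters

    def attr_of(lexeme, attr):
        # One accessor, selected by enum value: the idea behind cpdef attr_of above.
        return (lexeme.sic, lexeme.lex, lexeme.norm, lexeme.shape, lexeme.last3)[attr]

    word = Lexeme(sic=11, lex=22, norm=33, shape=44, last3=55)
    for attr in StringAttr:
        print(attr.name, attr_of(word, attr))   # iterate every hashed view of one lexeme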
@@ -13,6 +13,28 @@ from libcpp.vector cimport vector
 
 from spacy.spacy cimport StringHash
 
+# Reiterate the enum, for python
+#SIC = StringAttr.sic
+#LEX = StringAttr.lex
+#NORM = StringAttr.norm
+#SHAPE = StringAttr.shape
+#LAST3 = StringAttr.last3
+
+
+cpdef StringHash attr_of(size_t lex_id, StringAttr attr) except 0:
+    if attr == SIC:
+        return sic_of(lex_id)
+    elif attr == LEX:
+        return lex_of(lex_id)
+    elif attr == NORM:
+        return norm_of(lex_id)
+    elif attr == SHAPE:
+        return shape_of(lex_id)
+    elif attr == LAST3:
+        return last3_of(lex_id)
+    else:
+        raise StandardError
+
+
 cpdef StringHash sic_of(size_t lex_id) except 0:
     '''Access the `sic' field of the Lexeme pointed to by lex_id.
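A note on the dispatch above: because attr is a typed enum argument, the if/elif chain should compile to plain C comparisons, and the bare raise StandardError marks this as Python 2-era code (that exception class no longer exists in Python 3). The same routing could instead be written as a lookup table; a hedged pure-Python sketch, where the accessors and keys are illustrative stand-ins rather than the module's cpdef functions:

    # Sketch only: dict-of-callables dispatch as an alternative to the elif chain.
    def sic_of(lex):   return lex["sic"]
    def lex_of(lex):   return lex["lex"]
    def norm_of(lex):  return lex["norm"]
    def shape_of(lex): return lex["shape"]
    def last3_of(lex): return lex["last3"]

    ACCESSORS = {"SIC": sic_of, "LEX": lex_of, "NORM": norm_of,
                 "SHAPE": shape_of, "LAST3": last3_of}

    def attr_of(lex, attr):
        try:
            return ACCESSORS[attr](lex)
        except KeyError:
            raise ValueError("unknown attribute: %r" % (attr,))

    word = {"sic": 1, "lex": 2, "norm": 3, "shape": 4, "last3": 5}
    print(attr_of(word, "SHAPE"))   # 4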
@@ -58,6 +80,17 @@ cpdef StringHash shape_of(size_t lex_id) except 0:
     return (<Lexeme*>lex_id).orth.shape
 
 
+cpdef StringHash last3_of(size_t lex_id) except 0:
+    '''Access the `last3' field of the Lexeme pointed to by lex_id, which stores
+    the hash of the last three characters of the word:
+
+    >>> lex_ids = [lookup(w) for w in (u'Hello', u'!')]
+    >>> [unhash(last3_of(lex_id)) for lex_id in lex_ids]
+    [u'llo', u'!']
+    '''
+    return (<Lexeme*>lex_id).orth.last3
+
+
 cpdef ClusterID cluster_of(size_t lex_id):
     '''Access the `cluster' field of the Lexeme pointed to by lex_id, which
     gives an integer representation of the cluster ID of the word,
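The doctest in that docstring pins down the intended behaviour: last3_of returns the hash of the word's final three characters (lookup and unhash are the module's own helpers). A tiny pure-Python check of the string logic it hashes:

    def last3(word):
        # The substring whose hash last3_of stores: at most the three trailing characters.
        return word[-3:]

    print([last3(w) for w in (u'Hello', u'!')])   # ['llo', '!']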
@@ -101,17 +134,6 @@ cpdef double prob_of(size_t lex_id):
     return (<Lexeme*>lex_id).dist.prob
 
 
-cpdef StringHash last3_of(size_t lex_id):
-    '''Access the `last3' field of the Lexeme pointed to by lex_id, which stores
-    the hash of the last three characters of the word:
-
-    >>> lex_ids = [lookup(w) for w in (u'Hello', u'!')]
-    >>> [unhash(last3_of(lex_id)) for lex_id in lex_ids]
-    [u'llo', u'!']
-    '''
-    return (<Lexeme*>lex_id).orth.last3
-
-
 cpdef bint is_oft_upper(size_t lex_id):
     '''Access the `oft_upper' field of the Lexeme pointed to by lex_id, which
     stores whether the lowered version of the string hashed by `lex' is found
@@ -3,9 +3,7 @@ from spacy.spacy cimport Lexeme_addr
 
 from cython.operator cimport dereference as deref
 from spacy.spacy cimport Language
+from spacy.lexeme cimport StringAttr
 
-cdef enum Field:
-    lex
-
 
 cdef class Tokens:
@@ -16,5 +14,5 @@ cdef class Tokens:
     cpdef int append(self, Lexeme_addr token)
     cpdef int extend(self, Tokens other) except -1
 
-    cpdef list group_by(self, Field attr)
-    cpdef dict count_by(self, Field attr)
+    cpdef list group_by(self, StringAttr attr)
+    cpdef dict count_by(self, StringAttr attr)
@@ -3,7 +3,7 @@ from cython.operator cimport preincrement as inc
 
 
 from spacy.lexeme cimport Lexeme
-from spacy.lexeme cimport norm_of, shape_of
+from spacy.lexeme cimport attr_of, norm_of, shape_of
 from spacy.spacy cimport StringHash
 
 
@@ -37,15 +37,28 @@ cdef class Tokens:
         for el in other:
             self.append(el)
 
-    cpdef list group_by(self, Field attr):
-        pass
-
-    cpdef dict count_by(self, Field attr):
+    cpdef list group_by(self, StringAttr attr):
+        cdef dict indices = {}
+        cdef vector[vector[Lexeme_addr]] groups = vector[vector[Lexeme_addr]]()
+
+        cdef StringHash key
+        cdef Lexeme_addr t
+        for t in self.vctr[0]:
+            key = attr_of(t, attr)
+            if key in indices:
+                groups[indices[key]].push_back(t)
+            else:
+                indices[key] = groups.size()
+                groups.push_back(vector[Lexeme_addr]())
+                groups.back().push_back(t)
+        return groups
+
+    cpdef dict count_by(self, StringAttr attr):
         counts = {}
         cdef Lexeme_addr t
         cdef StringHash key
         for t in self.vctr[0]:
-            key = (<Lexeme*>t).lex
+            key = attr_of(t, attr)
            if key not in counts:
                 counts[key] = 0
             counts[key] += 1
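The headline change of the commit is the group_by body above: it walks the token vector once, uses a dict to map each attribute hash to a slot in a C++ vector of vectors, and pushes every token address into its group, while count_by now keys its tallies by the same attr_of hash. A pure-Python sketch of the same bookkeeping, where tokens and key_fn stand in for the Lexeme_addr vector and attr_of(t, attr) (illustrative names, not spaCy API):

    def group_by(tokens, key_fn):
        indices = {}                    # attribute hash -> position in `groups`
        groups = []
        for t in tokens:
            key = key_fn(t)
            if key in indices:
                groups[indices[key]].append(t)
            else:
                indices[key] = len(groups)
                groups.append([t])      # new group seeded with the current token
        return groups

    def count_by(tokens, key_fn):
        counts = {}
        for t in tokens:
            key = key_fn(t)
            counts[key] = counts.get(key, 0) + 1
        return counts

    words = ["the", "cat", "sat", "on", "the", "mat"]
    print(group_by(words, len))         # [['the', 'cat', 'sat', 'the', 'mat'], ['on']]
    print(count_by(words, len))         # {3: 5, 2: 1}

In the Cython version only the indices dict lives at the Python level; the groups themselves stay in C++ vectors of Lexeme_addr, which keeps the per-token work to one hash lookup.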