* Fixed group_by, removed idea of general attr_of function.

Matthew Honnibal 2014-08-22 00:02:37 +02:00
parent 811b7a6b91
commit 07ecf5d2f4
4 changed files with 16 additions and 50 deletions
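In short: rather than routing every attribute lookup through a generic attr_of(lex_id, attr) dispatch, callers now read the Lexeme fields directly, and Tokens.group_by takes a string-view index. A hypothetical usage sketch, not part of the commit: the call shapes come from the declarations in the diffs below, and the exact contents of group_by's three-list return value are not fully visible here, so the result is left unpacked.

    # Hypothetical usage sketch -- assumes the English module exposes the
    # module-level tokenize() and Tokens.group_by() declared in the diffs below.
    from spacy import en

    tokens = en.tokenize(u"give it back , give it back")

    # view_idx 0 keys the groups on the lexeme itself; 1..3 select the cached
    # string views (NORM, SHAPE, LAST3 in the enum this commit deletes).
    result = tokens.group_by(0)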

View File

@@ -10,30 +10,14 @@ from spacy.tokens cimport Tokens
 cimport cython

-ctypedef fused AttrType:
-    ClusterID
-    StringHash
-    cython.char
-
-cdef enum AttrName:
-    LEX
-    FIRST
-    LENGTH
-    CLUSTER
-    NORM
-    SHAPE
-    LAST3
-
 cdef class English(spacy.Language):
     cdef int find_split(self, unicode word)
     cdef int set_orth(self, unicode word, Lexeme* lex) except -1
-    cdef AttrType attr_of(self, LexID lex_id, AttrName attr) except *

 cdef English EN

 cpdef LexID lookup(unicode word) except 0
 cpdef Tokens tokenize(unicode string)
 cpdef unicode unhash(StringHash hash_value)

View File

@@ -76,27 +76,6 @@ cdef class English(spacy.Language):
             i += 1
         return i

-    cdef AttrType attr_of(self, LexID lex_id, AttrName attr) except *:
-        cdef Lexeme* w = <Lexeme*>lex_id
-        if attr == LEX:
-            return <AttrType>w.lex
-        elif attr == FIRST:
-            return w.string[0]
-        elif attr == LENGTH:
-            return w.length
-        elif attr == CLUSTER:
-            return w.cluster
-        elif attr == NORM:
-            return w.string_views[0]
-        elif attr == SHAPE:
-            return w.string_views[1]
-        elif attr == LAST3:
-            return w.string_views[2]
-        else:
-            raise AttributeError(attr)
-
 cdef bint check_punct(unicode word, size_t i, size_t length):
     # Don't count appostrophes as punct if the next char is a letter

View File

@@ -1,5 +1,5 @@
 from libcpp.vector cimport vector
-from spacy.spacy cimport Lexeme_addr
+from spacy.lexeme cimport LexID
 from spacy.lexeme cimport Lexeme
 from cython.operator cimport dereference as deref
@@ -8,10 +8,10 @@ from spacy.spacy cimport Language
 cdef class Tokens:
     cdef Language lang
-    cdef vector[Lexeme_addr]* vctr
+    cdef vector[LexID]* vctr
     cdef size_t length

-    cpdef int append(self, Lexeme_addr token)
+    cpdef int append(self, LexID token)
     cpdef int extend(self, Tokens other) except -1
     cpdef object group_by(self, size_t attr)

View File

@@ -9,14 +9,14 @@ from spacy.spacy cimport StringHash
 cdef class Tokens:
     def __cinit__(self, Language lang):
         self.lang = lang
-        self.vctr = new vector[Lexeme_addr]()
+        self.vctr = new vector[LexID]()
         self.length = 0

     def __dealloc__(self):
         del self.vctr

     def __iter__(self):
-        cdef vector[Lexeme_addr].iterator it = self.vctr[0].begin()
+        cdef vector[LexID].iterator it = self.vctr[0].begin()
         while it != self.vctr[0].end():
             yield deref(it)
             inc(it)
@@ -27,16 +27,16 @@ cdef class Tokens:
     def __len__(self):
         return self.length

-    cpdef int append(self, Lexeme_addr token):
+    cpdef int append(self, LexID token):
         self.vctr[0].push_back(token)
         self.length += 1

     cpdef int extend(self, Tokens other) except -1:
-        cdef Lexeme_addr el
+        cdef LexID el
         for el in other:
             self.append(el)

-    cpdef object group_by(self, size_t attr):
+    cpdef object group_by(self, size_t view_idx):
         '''Group tokens that share the property attr into Tokens instances, and
         return a list of them. Returns a tuple of three lists:
@@ -63,9 +63,12 @@ cdef class Tokens:
         cdef list hashes = []
         cdef StringHash key
-        cdef Lexeme_addr t
+        cdef LexID t
         for t in self.vctr[0]:
-            key = self.lang.attr_of(t, attr)
+            if view_idx == 0:
+                key = (<Lexeme*>t).lex
+            else:
+                key = (<Lexeme*>t).string_views[view_idx - 1]
             if key in indices:
                 groups[indices[key]].append(t)
             else:
@@ -78,7 +81,7 @@ cdef class Tokens:
     cpdef dict count_by(self, size_t attr):
         counts = {}
-        cdef Lexeme_addr t
+        cdef LexID t
         cdef StringHash key
         for t in self.vctr[0]:
             #key = attr_of(t, attr)
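
For reference, the new grouping key in group_by amounts to the mapping below. This is a pure-Python paraphrase of the hunk above, not code from the commit; the view names follow the AttrName enum deleted from the declarations in the first file.

    # Pure-Python paraphrase of the key selection in group_by above.
    # string_views indices follow the deleted enum: 0=NORM, 1=SHAPE, 2=LAST3.
    def grouping_key(lexeme, view_idx):
        if view_idx == 0:
            return lexeme.lex                      # the lexeme's own hash
        return lexeme.string_views[view_idx - 1]   # one of the cached views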