From 571808a274e220d05883a3bab63245d94bd767e0 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 7 Jul 2014 20:27:02 +0200 Subject: [PATCH] Group-by seems to be working --- spacy/lexeme.pxd | 30 ++++++++++++------------------ spacy/lexeme.pyx | 44 +++++++++++++++++++++++++++++++++----------- spacy/tokens.pxd | 8 +++----- spacy/tokens.pyx | 23 ++++++++++++++++++----- 4 files changed, 66 insertions(+), 39 deletions(-) diff --git a/spacy/lexeme.pxd b/spacy/lexeme.pxd index 041bdcc47..50417e65a 100644 --- a/spacy/lexeme.pxd +++ b/spacy/lexeme.pxd @@ -39,24 +39,18 @@ cdef struct Lexeme: cdef Lexeme BLANK_WORD = Lexeme(0, 0, NULL, NULL, NULL) +cdef enum StringAttr: + SIC + LEX + NORM + SHAPE + LAST3 + + +cpdef StringHash attr_of(size_t lex_id, StringAttr attr) except 0 + +cpdef StringHash sic_of(size_t lex_id) except 0 cpdef StringHash lex_of(size_t lex_id) except 0 cpdef StringHash norm_of(size_t lex_id) except 0 cpdef StringHash shape_of(size_t lex_id) except 0 -#cdef Lexeme* init_lexeme(Language lang, unicode string, StringHash hashed, -# int split, size_t length) - - - -# Use these to access the Lexeme fields via get_attr(Lexeme*, LexAttr), which -# has a conditional to pick out the correct item. This allows safe iteration -# over the Lexeme, via: -# for field in range(LexAttr.n): get_attr(Lexeme*, field) -cdef enum HashFields: - sic - lex - normed - cluster - n - - -#cdef uint64_t get_attr(Lexeme* word, HashFields attr) +cpdef StringHash last3_of(size_t lex_id) except 0 diff --git a/spacy/lexeme.pyx b/spacy/lexeme.pyx index 551b88442..e769a6bee 100644 --- a/spacy/lexeme.pyx +++ b/spacy/lexeme.pyx @@ -13,6 +13,28 @@ from libcpp.vector cimport vector from spacy.spacy cimport StringHash +# Reiterate the enum, for python +#SIC = StringAttr.sic +#LEX = StringAttr.lex +#NORM = StringAttr.norm +#SHAPE = StringAttr.shape +#LAST3 = StringAttr.last3 + + +cpdef StringHash attr_of(size_t lex_id, StringAttr attr) except 0: + if attr == SIC: + return sic_of(lex_id) + elif attr == LEX: + return lex_of(lex_id) + elif attr == NORM: + return norm_of(lex_id) + elif attr == SHAPE: + return shape_of(lex_id) + elif attr == LAST3: + return last3_of(lex_id) + else: + raise StandardError + cpdef StringHash sic_of(size_t lex_id) except 0: '''Access the `sic' field of the Lexeme pointed to by lex_id. @@ -58,6 +80,17 @@ cpdef StringHash shape_of(size_t lex_id) except 0: return (lex_id).orth.shape +cpdef StringHash last3_of(size_t lex_id) except 0: + '''Access the `last3' field of the Lexeme pointed to by lex_id, which stores + the hash of the last three characters of the word: + + >>> lex_ids = [lookup(w) for w in (u'Hello', u'!')] + >>> [unhash(last3_of(lex_id)) for lex_id in lex_ids] + [u'llo', u'!'] + ''' + return (lex_id).orth.last3 + + cpdef ClusterID cluster_of(size_t lex_id): '''Access the `cluster' field of the Lexeme pointed to by lex_id, which gives an integer representation of the cluster ID of the word, @@ -101,17 +134,6 @@ cpdef double prob_of(size_t lex_id): return (lex_id).dist.prob -cpdef StringHash last3_of(size_t lex_id): - '''Access the `last3' field of the Lexeme pointed to by lex_id, which stores - the hash of the last three characters of the word: - - >>> lex_ids = [lookup(w) for w in (u'Hello', u'!')] - >>> [unhash(last3_of(lex_id)) for lex_id in lex_ids] - [u'llo', u'!'] - ''' - return (lex_id).orth.last3 - - cpdef bint is_oft_upper(size_t lex_id): '''Access the `oft_upper' field of the Lexeme pointed to by lex_id, which stores whether the lowered version of the string hashed by `lex' is found diff --git a/spacy/tokens.pxd b/spacy/tokens.pxd index db6ebe008..5b640c85a 100644 --- a/spacy/tokens.pxd +++ b/spacy/tokens.pxd @@ -3,9 +3,7 @@ from spacy.spacy cimport Lexeme_addr from cython.operator cimport dereference as deref from spacy.spacy cimport Language - -cdef enum Field: - lex +from spacy.lexeme cimport StringAttr cdef class Tokens: @@ -16,5 +14,5 @@ cdef class Tokens: cpdef int append(self, Lexeme_addr token) cpdef int extend(self, Tokens other) except -1 - cpdef list group_by(self, Field attr) - cpdef dict count_by(self, Field attr) + cpdef list group_by(self, StringAttr attr) + cpdef dict count_by(self, StringAttr attr) diff --git a/spacy/tokens.pyx b/spacy/tokens.pyx index 9d40ceb26..0b5c80370 100644 --- a/spacy/tokens.pyx +++ b/spacy/tokens.pyx @@ -3,7 +3,7 @@ from cython.operator cimport preincrement as inc from spacy.lexeme cimport Lexeme -from spacy.lexeme cimport norm_of, shape_of +from spacy.lexeme cimport attr_of, norm_of, shape_of from spacy.spacy cimport StringHash @@ -37,15 +37,28 @@ cdef class Tokens: for el in other: self.append(el) - cpdef list group_by(self, Field attr): - pass + cpdef list group_by(self, StringAttr attr): + cdef dict indices = {} + cdef vector[vector[Lexeme_addr]] groups = vector[vector[Lexeme_addr]]() - cpdef dict count_by(self, Field attr): + cdef StringHash key + cdef Lexeme_addr t + for t in self.vctr[0]: + key = attr_of(t, attr) + if key in indices: + groups[indices[key]].push_back(t) + else: + indices[key] = groups.size() + groups.push_back(vector[Lexeme_addr]()) + groups.back().push_back(t) + return groups + + cpdef dict count_by(self, StringAttr attr): counts = {} cdef Lexeme_addr t cdef StringHash key for t in self.vctr[0]: - key = (t).lex + key = attr_of(t, attr) if key not in counts: counts[key] = 0 counts[key] += 1