* Progress toward getting WordTree working. Tests pass, but so far it's slower.

Matthew Honnibal 2014-08-16 19:59:38 +02:00
parent 865cacfaf7
commit 34b68a18ab
4 changed files with 84 additions and 21 deletions

View File

@@ -1,5 +1,7 @@
 from libc.stdint cimport uint64_t
+from chartree cimport CharTree
+
 
 
 cdef class FixedTable:
     cdef size_t size
@@ -9,3 +11,15 @@ cdef class FixedTable:
     cdef size_t insert(self, uint64_t key, size_t value) nogil
     cdef size_t get(self, uint64_t key) nogil
     cdef int erase(self, uint64_t key) nogil
+
+
+cdef class WordTree:
+    cdef size_t max_length
+    cdef size_t default
+    cdef CharTree* _trees
+    cdef dict _dict
+
+    cdef size_t get(self, unicode string) except *
+    cdef int set(self, unicode string, size_t value) except *
+    cdef bint contains(self, unicode string) except *

View File

@@ -1,6 +1,8 @@
 from libc.stdlib cimport calloc, free
 import cython
+cimport chartree
+
 
 
 cdef class FixedTable:
     def __cinit__(self, const size_t size):
@@ -51,3 +53,46 @@ cdef class FixedTable:
 @cython.cdivision
 cdef inline size_t _find(uint64_t key, size_t size) nogil:
     return key % size
+
+
+cdef class WordTree:
+    def __cinit__(self, size_t default, size_t max_length):
+        self.max_length = max_length
+        self.default = default
+        self._trees = <CharTree*>calloc(max_length, sizeof(CharTree))
+        for i in range(self.max_length):
+            chartree.init(&self._trees[i], i)
+        self._dict = {}
+
+    cdef size_t get(self, unicode ustring) except *:
+        cdef bytes bstring = ustring.encode('utf8')
+        cdef size_t length = len(bstring)
+        if length >= self.max_length:
+            return self._dict.get(bstring, 0)
+        else:
+            return chartree.getitem(&self._trees[length], bstring)
+
+    cdef int set(self, unicode ustring, size_t value) except *:
+        cdef bytes bstring = ustring.encode('utf8')
+        cdef size_t length = len(bstring)
+        if length >= self.max_length:
+            self._dict[bstring] = value
+        else:
+            chartree.setitem(&self._trees[length], bstring, value)
+
+    cdef bint contains(self, unicode ustring) except *:
+        cdef bytes bstring = ustring.encode('utf8')
+        cdef size_t length = len(bstring)
+        if length >= self.max_length:
+            return bstring in self._dict
+        else:
+            return chartree.contains(&self._trees[length], bstring)
+
+    def __getitem__(self, unicode key):
+        return self.get(key)
+
+    def __setitem__(self, unicode key, size_t value):
+        self.set(key, value)
+
+    def __contains__(self, unicode key):
+        return self.contains(key)
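For reference, the dispatch the new class implements can be sketched in pure Python: keys are UTF-8 encoded, anything shorter than max_length bytes is routed to the CharTree for that exact byte-length, and longer strings fall back to an ordinary dict. The PyWordTree below is an illustrative stand-in only (plain dicts replace the C-level CharTree buckets); it is not part of the commit.

# Pure-Python sketch of WordTree's routing, for illustration only.
# Plain dicts stand in for the per-length CharTree buckets.
class PyWordTree:
    def __init__(self, default, max_length):
        self.default = default
        self.max_length = max_length
        self._trees = [{} for _ in range(max_length)]  # one bucket per byte-length
        self._dict = {}                                # fallback for long strings

    def _route(self, ustring):
        # Mirror the Cython code: compare by UTF-8 bytes, bucket by length.
        bstring = ustring.encode('utf8')
        if len(bstring) >= self.max_length:
            return self._dict, bstring
        return self._trees[len(bstring)], bstring

    def __getitem__(self, ustring):
        bucket, key = self._route(ustring)
        return bucket.get(key, self.default)

    def __setitem__(self, ustring, value):
        bucket, key = self._route(ustring)
        bucket[key] = value

    def __contains__(self, ustring):
        bucket, key = self._route(ustring)
        return key in bucket

As in the Cython version, every operation pays a unicode-to-UTF-8 encode before it touches a bucket, which is relevant to the performance note in the commit message.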

View File

@@ -4,6 +4,7 @@ from libc.stdint cimport uint64_t
 from sparsehash.dense_hash_map cimport dense_hash_map
 from _hashing cimport FixedTable
+from _hashing cimport WordTree
 
 
 # Circular import problems here
 ctypedef size_t Lexeme_addr
@@ -22,11 +23,12 @@ ctypedef int ClusterID
 from spacy.lexeme cimport Lexeme
 from spacy.lexeme cimport Distribution
 from spacy.lexeme cimport Orthography
+from spacy._hashing cimport WordTree
 
 
 cdef class Language:
     cdef object name
-    cdef Vocab* vocab
+    cdef WordTree vocab
     cdef Vocab* distri
     cdef Vocab* ortho
     cdef dict bacov
@@ -38,7 +40,7 @@ cdef class Language:
 
     cdef Orthography* lookup_orth(self, StringHash key, unicode lex) except NULL
     cdef Distribution* lookup_dist(self, StringHash key) except NULL
-    cdef Lexeme* new_lexeme(self, StringHash key, unicode lex) except NULL
+    cdef Lexeme* new_lexeme(self, unicode key, unicode lex) except NULL
     cdef Orthography* new_orth(self, StringHash hashed, unicode lex) except NULL
     cdef Distribution* new_dist(self, StringHash key) except NULL
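A note on the interface change above: the vocab moves from a Vocab* (a dense_hash_map keyed by a precomputed StringHash) to a WordTree keyed by the unicode string itself, so new_lexeme now takes the string and hashes internally. A hypothetical before/after of a call site, with plain dicts and Python's hash() standing in for the real types:

# Illustrative only; these helper names are stand-ins, not the commit's API.
def lookup_old(vocab, new_lexeme, string):
    hashed = hash(string)            # every caller hashed first
    word = vocab.get(hashed)
    return word if word is not None else new_lexeme(hashed, string)

def lookup_new(vocab, new_lexeme, string):
    word = vocab.get(string)         # WordTree is keyed by the string itself
    return word if word is not None else new_lexeme(string, string)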

View File

@@ -5,6 +5,7 @@ from libc.stdlib cimport calloc, free
 from libcpp.pair cimport pair
 from cython.operator cimport dereference as deref
+from murmurhash cimport mrmr
 
 
 from spacy.lexeme cimport Lexeme
 from spacy.lexeme cimport BLANK_WORD
@@ -15,6 +16,13 @@ from os import path
 cimport cython
 
+#cdef inline StringHash hash_string(unicode string, size_t length):
+#    '''Hash unicode with MurmurHash64A'''
+#    return hash(string)
+#    #cdef bytes byte_string = string.encode('utf8')
+#    #return mrmr.hash32(<char*>byte_string, len(byte_string) * sizeof(char), 0)
+
+
 def get_normalized(unicode lex, size_t length):
     if lex.isalpha() and lex.islower():
         return lex
@@ -56,10 +64,9 @@ cdef class Language:
     def __cinit__(self, name):
         self.name = name
         self.bacov = {}
-        self.vocab = new Vocab()
+        self.vocab = WordTree(0, 5)
         self.ortho = new Vocab()
         self.distri = new Vocab()
-        self.vocab[0].set_empty_key(0)
         self.distri[0].set_empty_key(0)
         self.ortho[0].set_empty_key(0)
         self.load_tokenization(util.read_tokenization(name))
@@ -93,9 +100,9 @@ cdef class Language:
         cdef StringHash hashed = hash(string)
         # First, check words seen 2+ times
-        cdef Lexeme* word_ptr = <Lexeme*>self.vocab[0][hashed]
+        cdef Lexeme* word_ptr = <Lexeme*>self.vocab.get(string)
         if word_ptr == NULL:
-            word_ptr = self.new_lexeme(hashed, string)
+            word_ptr = self.new_lexeme(string, string)
         return <Lexeme_addr>word_ptr
 
     cdef Lexeme_addr lookup_chunk(self, unicode string) except 0:
@@ -106,18 +113,16 @@ cdef class Language:
         cdef size_t length = len(string)
         if length == 0:
             return <Lexeme_addr>&BLANK_WORD
-        cdef StringHash hashed = hash(string)
         # First, check words seen 2+ times
-        cdef Lexeme* word_ptr = <Lexeme*>self.vocab[0][hashed]
+        cdef Lexeme* word_ptr = <Lexeme*>self.vocab.get(string)
         cdef int split
         if word_ptr == NULL:
             split = self.find_split(string, length)
             if split != 0 and split != -1 and split < length:
-                word_ptr = self.new_lexeme(hashed, string[:split])
+                word_ptr = self.new_lexeme(string, string[:split])
                 word_ptr.tail = <Lexeme*>self.lookup_chunk(string[split:])
-                self.bacov[hashed] = string
             else:
-                word_ptr = self.new_lexeme(hashed, string)
+                word_ptr = self.new_lexeme(string, string)
         return <Lexeme_addr>word_ptr
 
     cdef Orthography* lookup_orth(self, StringHash hashed, unicode lex):
@@ -132,14 +137,15 @@ cdef class Language:
             dist = self.new_dist(hashed)
         return dist
 
-    cdef Lexeme* new_lexeme(self, StringHash key, unicode string) except NULL:
+    cdef Lexeme* new_lexeme(self, unicode key, unicode string) except NULL:
         cdef Lexeme* word = <Lexeme*>calloc(1, sizeof(Lexeme))
-        word.sic = key
+        word.sic = hash(key)
         word.lex = hash(string)
         self.bacov[word.lex] = string
+        self.bacov[word.sic] = key
         word.orth = self.lookup_orth(word.lex, string)
         word.dist = self.lookup_dist(word.lex)
-        self.vocab[0][key] = <size_t>word
+        self.vocab.set(key, <size_t>word)
         return word
 
     cdef Orthography* new_orth(self, StringHash hashed, unicode lex) except NULL:
@@ -185,13 +191,10 @@ cdef class Language:
         cdef Lexeme* word
         cdef StringHash hashed
         for chunk, lex, tokens in token_rules:
-            hashed = hash(chunk)
-            word = <Lexeme*>self.new_lexeme(hashed, lex)
+            word = <Lexeme*>self.new_lexeme(chunk, lex)
             for i, lex in enumerate(tokens):
                 token_string = '%s:@:%d:@:%s' % (chunk, i, lex)
-                length = len(token_string)
-                hashed = hash(token_string)
-                word.tail = <Lexeme*>self.new_lexeme(hashed, lex)
+                word.tail = <Lexeme*>self.new_lexeme(token_string, lex)
                 word = word.tail
 
     def load_clusters(self):
@@ -208,8 +211,7 @@ cdef class Language:
             # the first 4 bits. See redshift._parse_features.pyx
            cluster = int(cluster_str[::-1], 2)
             upper_pc, title_pc = case_stats.get(token_string.lower(), (0.0, 0.0))
-            hashed = hash(token_string)
-            word = self.init_lexeme(hashed, token_string)
+            word = self.new_lexeme(token_string, token_string)
 
 cdef inline bint _is_whitespace(unsigned char c) nogil:
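On the commit message's "slower": structurally, each WordTree lookup now costs a UTF-8 encode plus a walk through a length-selected trie, where the old dense_hash_map did one hash and one probe; the `WordTree(0, 5)` constructed in `__cinit__` also sends every string of 5+ UTF-8 bytes to the dict fallback. Assuming the PyWordTree sketch from earlier, a rough and illustrative (not measured) way to feel the difference between the two paths:

import timeit

plain = {'the': 1}                          # stands in for the old hash-table path
tree = PyWordTree(default=0, max_length=5)  # same parameters as Language.__cinit__
tree['the'] = 1

# Same logical lookup; the tree pays an extra encode-and-dispatch per call.
print(timeit.timeit(lambda: plain['the'], number=1000000))
print(timeit.timeit(lambda: tree['the'], number=1000000))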