diff --git a/spacy/serialize.pxd b/spacy/serialize.pxd index df7f0fc65..d060382a4 100644 --- a/spacy/serialize.pxd +++ b/spacy/serialize.pxd @@ -4,6 +4,8 @@ from libc.stdint cimport int64_t from libc.stdint cimport int32_t from libc.stdint cimport uint64_t +from .vocab cimport Vocab + cdef struct Node: float prob @@ -18,6 +20,7 @@ cdef struct Code: cdef class Serializer: cdef list codecs + cdef Vocab vocab cdef class HuffmanCodec: diff --git a/spacy/serialize.pyx b/spacy/serialize.pyx index a354203d4..de217f74e 100644 --- a/spacy/serialize.pyx +++ b/spacy/serialize.pyx @@ -3,10 +3,15 @@ from libc.stdint cimport uint32_t from libc.stdint cimport int64_t from libc.stdint cimport int32_t from libc.stdint cimport uint64_t +from libcpp.queue cimport priority_queue +from libcpp.pair cimport pair from preshed.maps cimport PreshMap from murmurhash.mrmr cimport hash64 +from .tokens.doc cimport Doc +from .vocab cimport Vocab +from os import path import numpy cimport cython @@ -97,7 +102,7 @@ cdef class Serializer: def __init__(self, Vocab vocab, data_dir): model_dir = path.join(data_dir, 'bitter') self.vocab = vocab # Vocab owns the word codec, the big one - self.cfg = Config.read(model_dir, 'config') + #self.cfg = Config.read(model_dir, 'config') self.codecs = tuple([CodecWrapper(attr) for attr in self.cfg.attrs]) def __call__(self, doc_or_bits): @@ -129,7 +134,7 @@ cdef class Serializer: cdef bint is_spacy for id_ in ids: is_spacy = biterator.next() - doc.push_back(vocab.lexemes.at(id_), is_spacy) + doc.push_back(self.vocab.lexemes.at(id_), is_spacy) cdef int length = doc.length array = numpy.zeros(shape=(length, len(self.codecs)), dtype=numpy.int) @@ -139,20 +144,20 @@ cdef class Serializer: return doc -cdef class AttributeEncoder: +cdef class CodecWrapper: """Wrapper around HuffmanCodec""" def __init__(self, freqs, id=0): cdef uint64_t key cdef uint64_t count - cdef pair[uint64_t] item - cdef priority_queue[pair[uint64_t]] items + cdef pair[uint64_t, uint64_t] item + cdef priority_queue[pair[uint64_t, uint64_t]] items for key, count in freqs: item.first = count item.second = key items.push(item) - weights = array('f') - keys = array('i') + weights = [] #array('f') + keys = [] #array('i') key_to_i = PreshMap() i = 0 while not items.empty(): @@ -188,8 +193,8 @@ cdef class HuffmanCodec: eol (uint32_t): The index of the weight of the EOL symbol. """ - def __init__(self, float[:] weights, unt32_t eol): - self.codes.resize(len(probs)) + def __init__(self, float[:] weights, uint32_t eol): + self.codes.resize(len(weights)) for i in range(len(self.codes)): self.codes[i].bits = 0 self.codes[i].length = 0