* Move serialization into a Serializer class, with __call__ and train() API

This commit is contained in:
Matthew Honnibal 2015-07-16 11:22:35 +02:00
parent e2133d990e
commit 6c99e5f4aa
2 changed files with 85 additions and 1 deletions

View File

@ -16,8 +16,13 @@ cdef struct Code:
char length char length
cdef class Serializer:
    # A cdef class has no instance __dict__, so every attribute that the
    # implementation assigns or reads through `self` must be declared here.
    cdef list codecs
    # Assigned in __init__. Declared as plain objects so this header needs no
    # extra cimports; NOTE(review): tighten `vocab` to Vocab if it is
    # cimported in this .pxd.
    cdef object vocab
    cdef object lex
    # Read by train() and serialize(). NOTE(review): __init__ does not
    # populate these yet, so they default to None until it does.
    cdef object attrs
    cdef object freqs
    cdef object eol
cdef class HuffmanCodec:
    # C-level state of the codec.
    cdef vector[Node] nodes
    # One Code entry per input probability (resized to len(probs) in __init__).
    cdef vector[Code] codes
    # Index of the weight of the EOL symbol.
    cdef uint32_t eol
    # Identifier passed to the constructor (defaults to 0); Serializer reads
    # it to decide which token field the decoded values belong to.
    cdef int id

View File

@ -93,6 +93,84 @@ cdef class BitArray:
self.bit_of_byte = 0 self.bit_of_byte = 0
cdef class Serializer:
    """Manage the codecs and maintain a consistent format for I/O.

    Calling an instance dispatches on the argument type: a Doc is passed to
    serialize() and a BitArray is passed to deserialize().
    """
    def __init__(self, Vocab vocab, model_dir):
        """Create a serializer bound to a vocabulary.

        Args:
            vocab (Vocab): Vocabulary used to build Doc objects and look up
                lexemes while deserializing.
            model_dir: Currently unused.
        """
        # NOTE(review): `attrs`, `freqs` and `eol` are read by train() and
        # serialize() but are never assigned here; they have to be populated
        # (presumably from model_dir) before those methods can run.
        self.vocab = vocab
        self.lex = None
        self.codecs = []

    def __call__(self, doc_or_bits):
        """Serialize a Doc, or deserialize a BitArray.

        Args:
            doc_or_bits (Doc or BitArray): The object to convert.

        Returns:
            A BitArray when given a Doc, a Doc when given a BitArray.

        Raises:
            ValueError: If the argument is neither a Doc nor a BitArray.
        """
        if isinstance(doc_or_bits, Doc):
            return self.serialize(doc_or_bits)
        if isinstance(doc_or_bits, BitArray):
            return self.deserialize(doc_or_bits)
        raise ValueError(doc_or_bits)

    def train(self, doc):
        """Add the attribute values of one document to the frequency counts.

        Args:
            doc: Document exposing `to_array(attrs)` and `length`.
        """
        # Indexed as array[attribute, token], matching serialize() below.
        # NOTE(review): assumes doc.to_array returns one row per attribute --
        # confirm against Doc.to_array.
        array = doc.to_array(self.attrs)
        for i, attr in enumerate(self.attrs):
            freqs = self.freqs[attr]
            for j in range(doc.length):
                freqs.inc(array[i, j], 1)
            # Count the end-of-sequence symbol once per attribute per document.
            freqs.inc(self.eol, 1)

    def serialize(self, doc):
        """Encode a document into a new BitArray.

        Args:
            doc: Document exposing `to_array(attrs)`.

        Returns:
            BitArray: The bits written by each codec, in codec order.
        """
        bits = BitArray()
        array = doc.to_array(self.attrs)
        # Codec i encodes row i of the attribute array. The original passed
        # the codec list as enumerate()'s `start` argument and then used an
        # unbound `codec` name.
        for i, codec in enumerate(self.codecs):
            codec.encode(array[i], bits)
        return bits

    @cython.boundscheck(False)
    def deserialize(self, bits):
        """Rebuild a Doc from a BitArray.

        The first codec yields the lexeme ids. Each remaining codec yields
        one value per token for the field named by its `id`; codecs whose id
        is not one of HEAD, TAG, DEP, ENT_IOB or ENT_TYPE are decoded but
        their values are not stored.

        Args:
            bits (BitArray): The encoded document.

        Returns:
            Doc: The reconstructed document.
        """
        # Typed local so the C-level `lexemes` container can be accessed.
        cdef Vocab vocab = self.vocab
        cdef Doc doc = Doc(vocab)
        cdef int id_
        cdef bint is_spacy
        cdef int length
        cdef int i
        cdef attr_id_t attr_id
        cdef attr_t[:] values
        cdef TokenC* tokens

        # A single iterator is shared by every read so that each decoder
        # resumes where the previous one stopped.
        biterator = iter(bits)
        ids = self.codecs[0].decode(biterator)
        # NOTE(review): one flag bit per token is consumed here, but
        # serialize() in this class does not visibly write those bits.
        for id_ in ids:
            is_spacy = next(biterator)
            doc.push_back(vocab.lexemes.at(id_), is_spacy)

        length = doc.length
        # Fetched only after every push_back, as in the original ordering
        # (presumably push_back can move the token buffer -- TODO confirm).
        tokens = doc.data
        for codec in self.codecs[1:]:
            values = codec.decode(biterator)
            attr_id = codec.id
            if attr_id == HEAD:
                for i in range(length):
                    tokens[i].head = values[i]
            elif attr_id == TAG:
                for i in range(length):
                    tokens[i].tag = values[i]
            elif attr_id == DEP:
                for i in range(length):
                    tokens[i].dep = values[i]
            elif attr_id == ENT_IOB:
                for i in range(length):
                    tokens[i].ent_iob = values[i]
            elif attr_id == ENT_TYPE:
                for i in range(length):
                    tokens[i].ent_type = values[i]
        return doc

    def lex_codec(self):
        """Build the Huffman codec for lexeme ids from the vocabulary.

        Each lexeme's weight is exp(prob) of its stored `prob` value.

        Returns:
            HuffmanCodec: Codec constructed with eol=0 and id=ID.
        """
        cdef Vocab vocab = self.vocab
        # One size drives the allocation, the loop and the memoryview; the
        # original mixed len(self) with len(self.vocab).
        cdef int n = len(vocab)
        # `mem` owns the buffer and must stay alive until the codec is built.
        cdef Address mem = Address(n, sizeof(float))
        cdef float* probs = <float*>mem.ptr
        cdef float[:] cv_probs
        cdef int i
        for i in range(n):
            probs[i] = <float>c_exp(vocab.lexemes[i].prob)
        cv_probs = <float[:n]>probs
        return HuffmanCodec(cv_probs, 0, id=ID)
cdef class HuffmanCodec: cdef class HuffmanCodec:
"""Create a Huffman code table, and use it to pack and unpack sequences into """Create a Huffman code table, and use it to pack and unpack sequences into
byte strings. Emphasis is on efficiency, so API is quite strict: byte strings. Emphasis is on efficiency, so API is quite strict:
@ -109,7 +187,8 @@ cdef class HuffmanCodec:
eol (uint32_t): The index of the weight of the EOL symbol. eol (uint32_t): The index of the weight of the EOL symbol.
""" """
def __init__(self, float[:] probs, uint32_t eol): def __init__(self, float[:] probs, uint32_t eol, id=0):
self.id = id
self.eol = eol self.eol = eol
self.codes.resize(len(probs)) self.codes.resize(len(probs))
for i in range(len(self.codes)): for i in range(len(self.codes)):