diff --git a/spacy/serialize.pxd b/spacy/serialize.pxd index 4c81ccccd..df7f0fc65 100644 --- a/spacy/serialize.pxd +++ b/spacy/serialize.pxd @@ -16,8 +16,13 @@ cdef struct Code: char length +cdef class Serializer: + cdef list codecs + + cdef class HuffmanCodec: cdef vector[Node] nodes cdef vector[Code] codes cdef uint32_t eol + cdef int id diff --git a/spacy/serialize.pyx b/spacy/serialize.pyx index 3e78bb59e..1ea58092d 100644 --- a/spacy/serialize.pyx +++ b/spacy/serialize.pyx @@ -93,6 +93,84 @@ cdef class BitArray: self.bit_of_byte = 0 +cdef class Serializer: + # Manage codecs, maintain consistent format for io + def __init__(self, Vocab vocab, model_dir): + self.vocab = vocab + self.lex = None + self.codecs = [] + + def __call__(self, doc_or_bits): + if isinstance(doc_or_bits, Doc): + return self.serialize(doc_or_bits) + elif isinstance(doc_or_bits, BitArray): + return self.deserialize(doc_or_bits) + else: + raise ValueError(doc_or_bits) + + def train(self, doc): + array = doc.to_array(self.attrs) + for i, attr in enumerate(self.attrs): + for j in range(doc.length): + self.freqs[attr].inc(array[i, j], 1) + self.freqs[attr].inc(self.eol, 1) + + def serialize(self, doc): + bits = BitArray() + array = doc.to_array(self.attrs) + for i, attr in enumerate(self.attrs, self.codecs): + codec.encode(array[i,], bits) + return bits + + @cython.boundscheck(False) + def deserialize(self, bits): + cdef Doc doc = Doc(self.vocab) + biterator = iter(bits) + ids = self.codecs[0].decode(bits) + cdef int id_ + cdef bint is_spacy + for id_ in ids: + is_spacy = biterator.next() + doc.push_back(vocab.lexemes.at(id_), is_spacy) + + cdef int length = doc.length + cdef int i + cdef attr_t value + cdef attr_id_t attr_id + cdef attr_t[:] values + cdef TokenC* tokens = doc.data + for codec in vocab.codecs[1:]: + values = codec.decode(biterator) + attr_id = codec.id + if attr_id == HEAD: + for i in range(length): + tokens[i].head = values[i] + elif attr_id == TAG: + for i in range(length): + tokens[i].tag = values[i] + elif attr_id == DEP: + for i in range(length): + tokens[i].dep = values[i] + elif attr_id == ENT_IOB: + for i in range(length): + tokens[i].ent_iob = values[i] + elif attr_id == ENT_TYPE: + for i in range(length): + tokens[i].ent_type = values[i] + return doc + + def lex_codec(self): + cdef Address mem + cdef int i + cdef float[:] cv_probs + mem = Address(len(self), sizeof(float)) + probs = mem.ptr + for i in range(len(self.vocab)): + probs[i] = c_exp(self.lexemes[i].prob) + cv_probs = probs + return HuffmanCodec(cv_probs, 0, id=ID) + + cdef class HuffmanCodec: """Create a Huffman code table, and use it to pack and unpack sequences into byte strings. Emphasis is on efficiency, so API is quite strict: @@ -109,7 +187,8 @@ cdef class HuffmanCodec: eol (uint32_t): The index of the weight of the EOL symbol. """ - def __init__(self, float[:] probs, uint32_t eol): + def __init__(self, float[:] probs, uint32_t eol, id=0): + self.id = id self.eol = eol self.codes.resize(len(probs)) for i in range(len(self.codes)):