diff --git a/spacy/serialize/packer.pyx b/spacy/serialize/packer.pyx index 2f9305646..0edcbde04 100644 --- a/spacy/serialize/packer.pyx +++ b/spacy/serialize/packer.pyx @@ -7,7 +7,7 @@ from libcpp.pair cimport pair from cymem.cymem cimport Address, Pool from preshed.maps cimport PreshMap -from ..attrs cimport ID, SPACY, TAG, HEAD, DEP, ENT_IOB, ENT_TYPE +from ..attrs cimport ID, ORTH, SPACY, TAG, HEAD, DEP, ENT_IOB, ENT_TYPE from ..tokens.doc cimport Doc from ..vocab cimport Vocab from ..typedefs cimport attr_t @@ -45,14 +45,14 @@ def make_vocab_codec(Vocab vocab): cdef class _BinaryCodec: - def encode(self, src, bits): + def encode(self, attr_t[:] msg, BitArray bits): cdef int i - for i in range(len(src)): - bits.append(src[i]) + for i in range(len(msg)): + bits.append(msg[i]) - def decode(self, dest, bits, n): - for i in range(n): - dest[i] = bits.next() + def decode(self, bits, attr_t[:] msg): + for i in range(len(msg)): + msg[i] = bits.next() cdef class _AttributeCodec: @@ -62,6 +62,7 @@ cdef class _AttributeCodec: cdef HuffmanCodec _codec def __init__(self, freqs): + self.mem = Pool() cdef uint64_t key cdef uint64_t count cdef pair[uint64_t, uint64_t] item @@ -72,28 +73,30 @@ cdef class _AttributeCodec: item.first = count item.second = key items.push(item) - weights = numpy.array(shape=(len(freqs),), dtype=numpy.float32) + weights = numpy.ndarray(shape=(len(freqs),), dtype=numpy.float32) self._keys = self.mem.alloc(len(freqs), sizeof(attr_t)) self._map = PreshMap() cdef int i = 0 while not items.empty(): item = items.top() - weights[i] = item.first + # We put freq first above, for sorting self._keys[i] = item.second - self._map[self.keys[i]] = i + weights[i] = item.first + self._map[self._keys[i]] = i items.pop() + i += 1 self._codec = HuffmanCodec(weights) - def encode(self, attr_t[:] msg, BitArray into_bits): + def encode(self, attr_t[:] msg, BitArray dest): for i in range(len(msg)): - msg[i] = self._map[msg[i]] - self._codec.encode(msg, into_bits) + msg[i] = self._map[msg[i]] + self._codec.encode(msg, dest) - def decode(self, BitArray bits, attr_t[:] into_msg): + def decode(self, BitArray bits, attr_t[:] dest): cdef int i - self._codec.decode(bits, into_msg) - for i in range(len(into_msg)): - into_msg[i] = self._keys[into_msg[i]] + self._codec.decode(bits, dest) + for i in range(len(dest)): + dest[i] = self._keys[dest[i]] cdef class Packer: @@ -103,7 +106,7 @@ cdef class Packer: self.attrs = [] for attr, freqs in list_of_attr_freqs: - if attr == ID: + if attr == ORTH: codecs.append(make_vocab_codec(vocab)) elif attr == SPACY: codecs.append(_BinaryCodec()) @@ -112,15 +115,8 @@ cdef class Packer: self.attrs.append(attr) self._codecs = tuple(codecs) - def __call__(self, msg_or_bits): - if isinstance(msg_or_bits, BitArray): - bits = msg_or_bits - return Doc.from_array(self.vocab, self.attrs, self.deserialize(bits)) - else: - msg = msg_or_bits - return self.serialize(msg.to_array(self.attrs)) - - def serialize(self, array): + def pack(self, Doc doc): + array = doc.to_array(self.attrs) cdef BitArray bits = BitArray() cdef uint32_t length = len(array) bits.extend(length, 32) @@ -128,9 +124,9 @@ cdef class Packer: codec.encode(array[i], bits) return bits - def deserialize(self, bits): + def unpack(self, bits): cdef uint32_t length = bits.read(32) array = numpy.ndarray(shape=(len(self.codecs), length), dtype=numpy.int) for i, codec in enumerate(self.codecs): array[i] = codec.decode(bits) - return array + return Doc.from_array(self.vocab, self.attrs, array)