diff --git a/spacy/serialize.pyx b/spacy/serialize.pyx
index 07f9a95f8..15053b4d4 100644
--- a/spacy/serialize.pyx
+++ b/spacy/serialize.pyx
@@ -11,6 +11,8 @@ import numpy
 
 cimport cython
 
+ctypedef unsigned char uchar
+
 # Format
 # - Total number of bytes in message (32 bit int)
 # - Words, terminating in an EOL symbol, huffman coded ~12 bits per word
@@ -33,14 +35,41 @@ cdef Code bit_append(Code code, bint bit) nogil:
 
 
 cdef class BitArray:
-    cdef int length
     cdef bytes data
     cdef unsigned char byte
     cdef unsigned char bit_of_byte
+    # Bit offset at which __iter__ starts reading.
+    cdef uint32_t i
     def __init__(self):
         self.data = b''
         self.byte = 0
         self.bit_of_byte = 0
+        self.i = 0
+
+    def __iter__(self):
+        """Yield the stored bits as 0 or 1, starting at bit offset self.i.
+
+        Bits come out least-significant-first within each byte: first the
+        completed bytes in self.data, then the bits of the pending byte.
+        """
+        cdef uchar byte, i
+        cdef uchar one = 1
+        cdef Py_ssize_t start_byte = self.i // 8
+        cdef uchar start_bit = self.i % 8
+        if start_bit != 0 and start_byte < len(self.data):
+            # The offset falls inside a completed byte: emit the bits from
+            # start_bit up to 7, not the ones that precede the offset.
+            byte = self.data[start_byte]
+            for i in range(start_bit, 8):
+                yield 1 if byte & (one << i) else 0
+            start_byte += 1
+            start_bit = 0
+        for byte in self.data[start_byte:]:
+            for i in range(8):
+                yield 1 if byte & (one << i) else 0
+        # Pending byte: only the first bit_of_byte bits have been written.
+        for i in range(start_bit, self.bit_of_byte):
+            yield 1 if self.byte & (one << i) else 0
 
     def as_bytes(self):
         if self.bit_of_byte != 0:
@@ -48,6 +77,22 @@ cdef class BitArray:
         else:
             return self.data
 
+    def append(self, bint bit):
+        """Write a single bit at the end of the array."""
+        cdef uint64_t one = 1
+        if bit:
+            self.byte |= one << self.bit_of_byte
+        else:
+            self.byte &= ~(one << self.bit_of_byte)
+        self.bit_of_byte += 1
+        if self.bit_of_byte == 8:
+            # The pending byte is full: flush it. bytes(bytearray(...)) builds
+            # a one-byte string on both Python 2 and 3, where chr() would give
+            # a unicode str on Python 3 that cannot be added to bytes.
+            self.data += bytes(bytearray((self.byte,)))
+            self.byte = 0
+            self.bit_of_byte = 0
+
     cdef int extend(self, uint64_t code, char n_bits) except -1:
         cdef uint64_t one = 1
         cdef unsigned char bit_of_code
@@ -91,31 +136,38 @@ cdef class HuffmanCodec:
         path.length = 0
         assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path)
 
-    def encode(self, uint32_t[:] sequence):
-        cdef BitArray bits = BitArray()
+    def encode(self, uint32_t[:] sequence, BitArray bits=None):
+        """Append the codes for sequence, then the EOL code, to bits.
+
+        A fresh BitArray is created when bits is None. Returns the BitArray
+        that was written to.
+        """
+        if bits is None:
+            bits = BitArray()
         for i in sequence:
             bits.extend(self.codes[i].bits, self.codes[i].length)
         bits.extend(self.codes[self.eol].bits, self.codes[self.eol].length)
-        return bits.as_bytes()
+        return bits
 
-    def decode(self, bytes data):
+    def decode(self, BitArray bits):
+        """Walk the code tree over bits and return the decoded symbols.
+
+        Stops at the EOL symbol, or when bits is exhausted.
+        """
         node = self.nodes.back()
         symbols = []
-        cdef unsigned char byte
-        cdef unsigned char i = 0
-        cdef unsigned char one = 1
-        for byte in data:
-            for i in range(8):
-                branch = node.right if (byte & (one << i)) else node.left
-                if branch >= 0:
-                    node = self.nodes.at(branch)
+        for bit in bits:
+            branch = node.right if bit else node.left
+            if branch >= 0:
+                node = self.nodes.at(branch)
+            else:
+                # Negative branch values encode a leaf: symbol = -(branch + 1).
+                symbol = -(branch + 1)
+                if symbol == self.eol:
+                    return symbols
                 else:
-                    symbol = -(branch + 1)
-                    if symbol == self.eol:
-                        return symbols
-                    else:
-                        symbols.append(symbol)
-                        node = self.nodes.back()
+                    symbols.append(symbol)
+                    node = self.nodes.back()
         return symbols
 
     property strings: