* Make huffman coder take BitArray in encode/decode. Add __iter__ method to BitArray.

This commit is contained in:
Matthew Honnibal 2015-07-13 17:33:33 +02:00
parent af5cc926a4
commit edd371246c

View File

@ -11,6 +11,8 @@ import numpy
cimport cython cimport cython
ctypedef unsigned char uchar
# Format # Format
# - Total number of bytes in message (32 bit int) # - Total number of bytes in message (32 bit int)
# - Words, terminating in an EOL symbol, huffman coded ~12 bits per word # - Words, terminating in an EOL symbol, huffman coded ~12 bits per word
@ -33,14 +35,29 @@ cdef Code bit_append(Code code, bint bit) nogil:
cdef class BitArray: cdef class BitArray:
cdef int length
cdef bytes data cdef bytes data
cdef unsigned char byte cdef unsigned char byte
cdef unsigned char bit_of_byte cdef unsigned char bit_of_byte
cdef uint32_t i
def __init__(self):
    """Create an empty bit array.

    No completed bytes are stored, no bits are pending in the partial
    byte, and iteration starts at bit offset 0.
    """
    # Pending (not yet complete) byte and the number of bits it holds.
    self.bit_of_byte = 0
    self.byte = 0
    # Completed bytes.
    self.data = b''
    # Bit offset from which __iter__ starts yielding.
    self.i = 0
def __iter__(self):
    """Yield the stored bits, starting at bit offset ``self.i``.

    Bits are produced least-significant bit first within each byte:
    first the completed bytes in ``self.data``, then the
    ``self.bit_of_byte`` pending bits held in ``self.byte``.

    Each yielded value is the byte masked with the bit's position
    (zero, or a non-zero power of two), so callers should test its
    truthiness rather than compare it with 1.
    """
    cdef uchar byte, i
    cdef uchar one = 1
    start_byte = self.i // 8
    start_bit = self.i % 8
    n_full = len(self.data)
    if start_byte > n_full:
        # The offset lies beyond everything stored: nothing to yield.
        return
    if start_bit != 0 and start_byte < n_full:
        # The offset falls inside a completed byte: emit only the bits at
        # or after it. A one-byte slice is iterated (instead of indexing)
        # so the byte is read exactly like in the whole-byte loop below.
        for byte in self.data[start_byte:start_byte + 1]:
            for i in range(start_bit, 8):
                yield byte & (one << i)
        start_byte += 1
        start_bit = 0
    for byte in self.data[start_byte:]:
        for i in range(8):
            yield byte & (one << i)
    # start_bit is still non-zero here only when the offset falls inside
    # the pending byte; otherwise all pending bits are emitted.
    for i in range(start_bit, self.bit_of_byte):
        yield self.byte & (one << i)
def as_bytes(self): def as_bytes(self):
if self.bit_of_byte != 0: if self.bit_of_byte != 0:
@ -48,6 +65,18 @@ cdef class BitArray:
else: else:
return self.data return self.data
def append(self, bint bit):
    """Append a single bit to the end of the array.

    The bit is written at position ``self.bit_of_byte`` of the pending
    byte ``self.byte``. Once eight bits have accumulated, the byte is
    moved onto ``self.data`` and the pending byte is reset.
    """
    cdef uint64_t one = 1
    if bit:
        self.byte |= one << self.bit_of_byte
    else:
        self.byte &= ~(one << self.bit_of_byte)
    self.bit_of_byte += 1
    if self.bit_of_byte == 8:
        # bytes(bytearray(...)) gives a one-byte bytes object on both
        # Python 2 and Python 3; chr() returns text on Python 3, which
        # cannot be concatenated to bytes.
        self.data += bytes(bytearray((self.byte,)))
        self.byte = 0
        self.bit_of_byte = 0
cdef int extend(self, uint64_t code, char n_bits) except -1: cdef int extend(self, uint64_t code, char n_bits) except -1:
cdef uint64_t one = 1 cdef uint64_t one = 1
cdef unsigned char bit_of_code cdef unsigned char bit_of_code
@ -91,31 +120,28 @@ cdef class HuffmanCodec:
path.length = 0 path.length = 0
assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path) assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path)
def encode(self, uint32_t[:] sequence, BitArray bits=None):
    """Write the codes for ``sequence`` into a BitArray and return it.

    The code of every symbol in ``sequence`` is appended in order,
    followed by the code of the end-of-line symbol ``self.eol``.

    sequence: symbols, used as indices into ``self.codes``.
    bits: BitArray to extend; a fresh one is created when None.
    Returns the BitArray that was written to.
    """
    if bits is None:
        bits = BitArray()
    for symbol in sequence:
        bits.extend(self.codes[symbol].bits, self.codes[symbol].length)
    # Terminate the message so decode() knows where to stop.
    bits.extend(self.codes[self.eol].bits, self.codes[self.eol].length)
    return bits
def decode(self, BitArray bits):
    """Decode symbols from ``bits`` until the end-of-line symbol.

    Starting from the last entry of ``self.nodes``, each bit selects the
    ``right`` (truthy bit) or ``left`` (falsy bit) branch of the current
    node. A non-negative branch is the index of the next node; a negative
    branch is a leaf encoding the symbol ``-(branch + 1)``.

    Returns the list of decoded symbols, excluding ``self.eol``. If the
    bits run out before ``self.eol`` is seen, the symbols decoded so far
    are returned.
    """
    decoded = []
    current = self.nodes.back()
    for bit in bits:
        if bit:
            branch = current.right
        else:
            branch = current.left
        if branch >= 0:
            # Internal node: keep descending.
            current = self.nodes.at(branch)
            continue
        symbol = -(branch + 1)
        if symbol == self.eol:
            return decoded
        decoded.append(symbol)
        # Leaf consumed: restart from the top of the tree.
        current = self.nodes.back()
    return decoded
property strings: property strings: