* Make huffman coder take BitArray in encode/decode. Add __iter__ method to BitArray.

This commit is contained in:
Matthew Honnibal 2015-07-13 17:33:33 +02:00
parent af5cc926a4
commit edd371246c

View File

@ -11,6 +11,8 @@ import numpy
cimport cython
ctypedef unsigned char uchar
# Format
# - Total number of bytes in message (32 bit int)
# - Words, terminating in an EOL symbol, huffman coded ~12 bits per word
@ -33,14 +35,29 @@ cdef Code bit_append(Code code, bint bit) nogil:
cdef class BitArray:
cdef int length
cdef bytes data
cdef unsigned char byte
cdef unsigned char bit_of_byte
cdef uint32_t i
def __init__(self):
self.data = b''
self.byte = 0
self.bit_of_byte = 0
self.i = 0
def __iter__(self):
cdef uchar byte, i
cdef uchar one = 1
start_byte = self.i // 8
if (self.i % 8) != 0:
for i in range(self.i % 8):
yield (self.data[start_byte] & (one << i))
start_byte += 1
for byte in self.data[start_byte:]:
for i in range(8):
yield byte & (one << i)
for i in range(self.bit_of_byte):
yield self.byte & (one << i)
def as_bytes(self):
if self.bit_of_byte != 0:
@ -48,6 +65,18 @@ cdef class BitArray:
else:
return self.data
def append(self, bint bit):
cdef uint64_t one = 1
if bit:
self.byte |= one << self.bit_of_byte
else:
self.byte &= ~(one << self.bit_of_byte)
self.bit_of_byte += 1
if self.bit_of_byte == 8:
self.data += chr(self.byte)
self.byte = 0
self.bit_of_byte = 0
cdef int extend(self, uint64_t code, char n_bits) except -1:
cdef uint64_t one = 1
cdef unsigned char bit_of_code
@ -91,22 +120,19 @@ cdef class HuffmanCodec:
path.length = 0
assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path)
def encode(self, uint32_t[:] sequence):
cdef BitArray bits = BitArray()
def encode(self, uint32_t[:] sequence, BitArray bits=None):
if bits is None:
bits = BitArray()
for i in sequence:
bits.extend(self.codes[i].bits, self.codes[i].length)
bits.extend(self.codes[self.eol].bits, self.codes[self.eol].length)
return bits.as_bytes()
return bits
def decode(self, bytes data):
def decode(self, BitArray bits):
node = self.nodes.back()
symbols = []
cdef unsigned char byte
cdef unsigned char i = 0
cdef unsigned char one = 1
for byte in data:
for i in range(8):
branch = node.right if (byte & (one << i)) else node.left
for bit in bits:
branch = node.right if bit else node.left
if branch >= 0:
node = self.nodes.at(branch)
else: