* Nearly finished huffman coder

This commit is contained in:
Matthew Honnibal 2015-07-12 23:48:46 +02:00
parent e1a25fba32
commit 281f1faefb

View File

@ -11,19 +11,12 @@ import numpy
cimport cython cimport cython
# Format
#cdef class Serializer: # - Total number of bytes in message (32 bit int)
# def __init__(self, Vocab vocab): # - Words, terminating in an EOL symbol, huffman coded ~12 bits per word
# pass # - Spaces ~1 bit per word
# # - Parse: Huffman coded head offset / dep label / POS tag / entity IOB tag
# def dump(self, Doc tokens, file_): # combo. ? bits per word. 40 * 80 * 40 * 12 = 1.5m symbol vocab
# pass
# # Format
# # - Total number of bytes in message (32 bit int)
# # - Words, terminating in an EOL symbol, huffman coded ~12 bits per word
# # - Spaces ~1 bit per word
# # - Parse: Huffman coded head offset / dep label / POS tag / entity IOB tag
# # combo. ? bits per word. 40 * 80 * 40 * 12 = 1.5m symbol vocab
cdef struct Node: cdef struct Node:
@ -53,21 +46,11 @@ cdef Code bit_append(Code code, bint bit) nogil:
cdef class HuffmanCodec: cdef class HuffmanCodec:
cdef vector[Node] nodes cdef vector[Node] nodes
cdef vector[Code] codes cdef vector[Code] codes
cdef float[:] probs cdef readonly float[:] probs
cdef PreshMap table cdef PreshMap table
def __init__(self, symbols, probs): cdef uint32_t eol
self.table = PreshMap() def __init__(self, probs, eol):
cdef bytes symb_str self.eol = eol
cdef uint64_t key
cdef uint32_t i
for i, symbol in enumerate(symbols):
if type(symbol) == unicode or type(symbol) == bytes:
symb_str = symbol.encode('utf8')
key = hash64(<unsigned char*>symb_str, len(symb_str), 0)
else:
key = int(symbol)
self.table[key] = i+1
self.symbols = symbols
self.probs = probs self.probs = probs
self.codes.resize(len(probs)) self.codes.resize(len(probs))
for i in range(len(self.codes)): for i in range(len(self.codes)):
@ -79,45 +62,45 @@ cdef class HuffmanCodec:
path.length = 0 path.length = 0
assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path) assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path)
def encode(self, sequence): def encode(self, uint32_t[:] sequence):
cdef vector[bint] bits cdef Code code
cdef uint64_t key cdef bytes output = b''
cdef uint64_t i cdef unsigned char byte = 0
for symbol in sequence: cdef uint64_t one = 1
if type(symbol) == unicode or type(symbol) == bytes: cdef unsigned char i_of_byte = 0
symb_str = symbol.encode('utf8') cdef unsigned char i_of_code = 0
key = hash64(<unsigned char*>symb_str, len(symb_str), 0) for index in sequence:
code = self.codes[index]
for i_of_code in range(code.length):
if code.bits & (one << i_of_code):
byte |= one << i_of_byte
else: else:
key = int(symbol) byte &= ~(one << i_of_byte)
i = <uint32_t>self.table.get(key) i_of_byte += 1
if i == 0: if i_of_byte == 8:
raise Exception("Unseen symbol: %s" % symbol) output += chr(byte)
else: byte = 0
code = self.codes[i] i_of_byte = 0
bits.extend(code) if i_of_byte != 0:
return bits output += chr(byte)
return output
def decode(self, unsigned char[:] data): def decode(self, bytes data):
symbols = []
node = self.nodes.back() node = self.nodes.back()
bits = [] symbols = []
cdef unsigned char byte cdef unsigned char byte
cdef unsigned char one cdef unsigned char i = 0
cdef int i = 0 cdef unsigned char one = 1
for byte_ in data: for byte in data:
for i in range(7, -1, -1): for i in range(8):
bits.append(bool(byte & (one << i))) branch = node.right if (byte & (one << i)) else node.left
cdef bint bit = 0
for bit in bits:
branch = node.right if bit else node.left
if branch >= 0: if branch >= 0:
node = self.nodes.at(branch) node = self.nodes.at(branch)
else: else:
symbol = self.symbols[-(branch + 1)] symbol = -(branch + 1)
if symbol == self.eol_symbol: if symbol == self.eol:
break return symbols
else:
symbols.append(symbol) symbols.append(symbol)
node = self.nodes.back() node = self.nodes.back()
return symbols return symbols