* Add faster encode_int32 and decode_int32 methods

This commit is contained in:
Matthew Honnibal 2015-07-21 19:58:45 +02:00
parent dd60594f41
commit c6cd0ddce8
2 changed files with 50 additions and 3 deletions

View File

@ -4,7 +4,7 @@ from libc.stdint cimport int64_t
from libc.stdint cimport int32_t from libc.stdint cimport int32_t
from libc.stdint cimport uint64_t from libc.stdint cimport uint64_t
from .bits cimport Code from .bits cimport BitArray, Code
cdef struct Node: cdef struct Node:
@ -19,3 +19,6 @@ cdef class HuffmanCodec:
cdef readonly list leaves cdef readonly list leaves
cdef readonly dict _map cdef readonly dict _map
cpdef int encode_int32(self, int32_t[:] msg, BitArray bits) except -1
cpdef int decode_int32(self, BitArray bits, int32_t[:] msg) except -1

View File

@ -61,6 +61,19 @@ cdef class HuffmanCodec:
bits.extend(self.codes[i].bits, self.codes[i].length) bits.extend(self.codes[i].bits, self.codes[i].length)
return bits return bits
cpdef int encode_int32(self, int32_t[:] msg, BitArray bits) except -1:
cdef int msg_i
cdef int leaf_i
cdef int length = 0
for msg_i in range(msg.shape[0]):
leaf_i = self._map.get(msg[msg_i], -1)
if leaf_i is -1:
return 0
code = self.codes[leaf_i]
bits.extend(code.bits, code.length)
length += code.length
return length
def n_bits(self, msg, overhead=0): def n_bits(self, msg, overhead=0):
cdef int i cdef int i
length = 0 length = 0
@ -88,8 +101,39 @@ cdef class HuffmanCodec:
if i == n: if i == n:
break break
else: else:
raise Exception( raise Exception("Buffer exhausted at %d/%d symbols read." % (i, len(msg)))
"Buffer exhausted at %d/%d symbols read." % (i, len(msg)))
@cython.boundscheck(False)
cpdef int decode_int32(self, BitArray bits, int32_t[:] msg) except -1:
cdef Node node = self.root
cdef int branch
cdef int n_msg = msg.shape[0]
cdef bytes bytes_ = bits.as_bytes()
cdef unsigned char byte
cdef int i_msg = 0
cdef int i_byte = 0
cdef int i_bit = 0
cdef unsigned char bit
cdef int32_t one = 1
while i_msg < n_msg:
byte = bytes_[i_byte]
for i_bit in range(8):
bit = byte & (one << i_bit)
branch = node.right if bit else node.left
if branch >= 0:
node = self.nodes.at(branch)
else:
msg[i_msg] = self.leaves[-(branch + 1)]
node = self.nodes.back()
i_msg += 1
if i_msg == n_msg:
break
i_byte += 1
else:
raise Exception("Buffer exhausted at %d/%d symbols read." % (i_msg, len(msg)))
property strings: property strings:
@cython.boundscheck(False) @cython.boundscheck(False)