* Tests passing on round-trip pack/unpack on basic example

This commit is contained in:
Matthew Honnibal 2015-07-17 21:20:48 +02:00
parent 44f39a876f
commit cf0c788892
7 changed files with 108 additions and 51 deletions

View File

@ -19,3 +19,5 @@ cdef class BitArray:
cdef uint32_t i
cdef int extend(self, uint64_t code, char n_bits) except -1
cdef uint32_t read32(self) except 0

View File

@ -1,4 +1,4 @@
from libc.string cimport memcpy
# Note that we're setting the most significant bits here first, when in practice
# we're actually wanting the last bit to be most significant (for Huffman coding,
@ -20,19 +20,62 @@ cdef class BitArray:
self.bit_of_byte = 0
self.i = 0
def __len__(self):
return 8 * len(self.data) + self.bit_of_byte
def __str__(self):
cdef uchar byte, i
cdef uchar one = 1
string = b''
for i in range(len(self.data)):
byte = ord(self.data[i])
for j in range(8):
string += b'1' if (byte & (one << j)) else b'0'
for i in range(self.bit_of_byte):
string += b'1' if (byte & (one << i)) else b'0'
return string
def seek(self, i):
self.i = i
def __iter__(self):
cdef uchar byte, i
cdef uchar one = 1
start_byte = self.i // 8
if (self.i % 8) != 0:
for i in range(self.i % 8):
yield 1 if (self.data[start_byte] & (one << i)) else 0
start_bit = self.i % 8
if start_bit != 0 and start_byte < len(self.data):
byte = ord(self.data[start_byte])
for i in range(start_bit, 8):
self.i += 1
yield 1 if (byte & (one << i)) else 0
start_byte += 1
start_bit = 0
for byte in self.data[start_byte:]:
for i in range(8):
self.i += 1
yield 1 if byte & (one << i) else 0
for i in range(self.bit_of_byte):
yield 1 if self.byte & (one << i) else 0
if self.bit_of_byte != 0:
byte = self.byte
for i in range(start_bit, self.bit_of_byte):
self.i += 1
yield 1 if self.byte & (one << i) else 0
cdef uint32_t read32(self) except 0:
cdef int start_byte = self.i // 8
# TODO portability
cdef uchar[4] chars
chars[0] = <uchar>ord(self.data[start_byte])
chars[1] = <uchar>ord(self.data[start_byte+1])
chars[2] = <uchar>ord(self.data[start_byte+2])
chars[3] = <uchar>ord(self.data[start_byte+3])
cdef uint32_t output
memcpy(&output, chars, 4)
self.i += 32
return output
def as_bytes(self):
if self.bit_of_byte != 0:
@ -47,6 +90,7 @@ cdef class BitArray:
else:
self.byte &= ~(one << self.bit_of_byte)
self.bit_of_byte += 1
self.i += 1
if self.bit_of_byte == 8:
self.data += chr(self.byte)
self.byte = 0
@ -65,5 +109,4 @@ cdef class BitArray:
self.data += chr(self.byte)
self.byte = 0
self.bit_of_byte = 0
self.i += 1

View File

@ -1,5 +1,7 @@
cimport cython
from ..typedefs cimport attr_t
from .bits cimport bit_append
from .bits cimport BitArray
@ -16,7 +18,6 @@ cdef class HuffmanCodec:
weights (float[:]): A descending-sorted sequence of probabilities/weights.
Must include a weight for an EOL symbol.
eol (uint32_t): The index of the weight of the EOL symbol.
"""
def __init__(self, float[:] weights):
self.codes.resize(len(weights))
@ -29,12 +30,12 @@ cdef class HuffmanCodec:
path.length = 0
assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path)
def encode(self, uint32_t[:] msg, BitArray into_bits):
cdef uint32_t i
def encode(self, attr_t[:] msg, BitArray into_bits):
cdef int i
for i in range(len(msg)):
into_bits.extend(self.codes[msg[i]].bits, self.codes[msg[i]].length)
def decode(self, bits, uint32_t[:] into_msg):
def decode(self, bits, attr_t[:] into_msg):
node = self.nodes.back()
cdef int i = 0
cdef int n = len(into_msg)
@ -49,7 +50,8 @@ cdef class HuffmanCodec:
if i == n:
break
else:
raise Exception
raise Exception(
"Buffer exhausted at %d/%d symbols read." % (i, len(into_msg)))
property strings:
@cython.boundscheck(False)

View File

@ -2,5 +2,6 @@ from ..vocab cimport Vocab
cdef class Packer:
cdef tuple _codecs
cdef Vocab vocab
cdef readonly tuple attrs
cdef readonly tuple _codecs
cdef readonly Vocab vocab

View File

@ -7,7 +7,7 @@ from libcpp.pair cimport pair
from cymem.cymem cimport Address, Pool
from preshed.maps cimport PreshMap
from ..attrs cimport ID, ORTH, SPACY, TAG, HEAD, DEP, ENT_IOB, ENT_TYPE
from ..attrs cimport ID, SPACY, TAG, HEAD, DEP, ENT_IOB, ENT_TYPE
from ..tokens.doc cimport Doc
from ..vocab cimport Vocab
from ..typedefs cimport attr_t
@ -50,24 +50,28 @@ cdef class _BinaryCodec:
for i in range(len(msg)):
bits.append(msg[i])
def decode(self, bits, attr_t[:] msg):
for i in range(len(msg)):
msg[i] = bits.next()
def decode(self, BitArray bits, attr_t[:] msg):
cdef int i = 0
for bit in bits:
msg[i] = bit
i += 1
if i == len(msg):
break
cdef class _AttributeCodec:
cdef Pool mem
cdef attr_t* _keys
cdef PreshMap _map
cdef dict _map
cdef HuffmanCodec _codec
def __init__(self, freqs):
self.mem = Pool()
cdef uint64_t key
cdef uint64_t count
cdef pair[uint64_t, uint64_t] item
cdef attr_t key
cdef float count
cdef pair[float, attr_t] item
cdef priority_queue[pair[uint64_t, uint64_t]] items
cdef priority_queue[pair[float, attr_t]] items
for key, count in freqs:
item.first = count
@ -75,7 +79,7 @@ cdef class _AttributeCodec:
items.push(item)
weights = numpy.ndarray(shape=(len(freqs),), dtype=numpy.float32)
self._keys = <attr_t*>self.mem.alloc(len(freqs), sizeof(attr_t))
self._map = PreshMap()
self._map = {}
cdef int i = 0
while not items.empty():
item = items.top()
@ -88,8 +92,9 @@ cdef class _AttributeCodec:
self._codec = HuffmanCodec(weights)
def encode(self, attr_t[:] msg, BitArray dest):
cdef int i
for i in range(len(msg)):
msg[i] = <attr_t>self._map[msg[i]]
msg[i] = self._map[msg[i]]
self._codec.encode(msg, dest)
def decode(self, BitArray bits, attr_t[:] dest):
@ -103,30 +108,36 @@ cdef class Packer:
def __init__(self, Vocab vocab, list_of_attr_freqs):
self.vocab = vocab
codecs = []
self.attrs = []
attrs = []
for attr, freqs in list_of_attr_freqs:
if attr == ORTH:
if attr == ID:
codecs.append(make_vocab_codec(vocab))
elif attr == SPACY:
codecs.append(_BinaryCodec())
else:
codecs.append(_AttributeCodec(freqs))
self.attrs.append(attr)
attrs.append(attr)
self._codecs = tuple(codecs)
self.attrs = tuple(attrs)
def pack(self, Doc doc):
array = doc.to_array(self.attrs)
cdef BitArray bits = BitArray()
cdef uint32_t length = len(array)
bits.extend(length, 32)
cdef uint32_t length = 3
#cdef uint32_t length = len(doc)
#bits.extend(length, 32)
for i, codec in enumerate(self._codecs):
codec.encode(array[i], bits)
codec.encode(array[:, i], bits)
return bits
def unpack(self, bits):
cdef uint32_t length = bits.read(32)
array = numpy.ndarray(shape=(len(self.codecs), length), dtype=numpy.int)
for i, codec in enumerate(self.codecs):
array[i] = codec.decode(bits)
return Doc.from_array(self.vocab, self.attrs, array)
def unpack(self, BitArray bits):
bits.seek(0)
#cdef uint32_t length = bits.read32()
cdef uint32_t length = 3
array = numpy.zeros(shape=(length, len(self._codecs)), dtype=numpy.int32)
for i, codec in enumerate(self._codecs):
codec.decode(bits, array[:, i])
doc = Doc.from_ids(self.vocab, array[:, 0], array[:, 1])
doc.from_array(self.attrs, array)
return doc

View File

@ -97,17 +97,15 @@ cdef class Doc:
self._py_tokens = []
@classmethod
def from_orth(cls, Vocab vocab, attr_t[:] orths, attr_t[:] spaces):
def from_ids(cls, Vocab vocab, ids, spaces):
cdef int i
cdef const LexemeC* lex
cdef Doc self = cls(vocab)
cdef unicode string
cdef UniStr new_orth_c
for i in range(len(orths)):
string = vocab.strings[orths[i]]
slice_unicode(&new_orth_c, string, 0, len(string))
lex = self.vocab.get(self.mem, &new_orth_c)
self.push_back(lex, spaces[i])
cdef bint space = 0
for i in range(len(ids)):
lex = self.vocab.lexemes.at(ids[i])
space = spaces[i]
self.push_back(lex, space)
return self
def __getitem__(self, object i):
@ -229,11 +227,11 @@ cdef class Doc:
"""
cdef int i, j
cdef attr_id_t feature
cdef np.ndarray[long, ndim=2] output
cdef np.ndarray[attr_t, ndim=2] output
# Make an array from the attributes --- otherwise our inner loop is Python
# dict iteration.
cdef np.ndarray[long, ndim=1] attr_ids = numpy.asarray(py_attr_ids)
output = numpy.ndarray(shape=(self.length, len(attr_ids)), dtype=numpy.int)
cdef np.ndarray[attr_t, ndim=1] attr_ids = numpy.asarray(py_attr_ids, dtype=numpy.int32)
output = numpy.ndarray(shape=(self.length, len(attr_ids)), dtype=numpy.int32)
for i in range(self.length):
for j, feature in enumerate(attr_ids):
output[i, j] = get_token_attr(&self.data[i], feature)

View File

@ -1,10 +1,10 @@
from libc.stdint cimport uint16_t, uint32_t, uint64_t, uintptr_t
from libc.stdint cimport uint16_t, uint32_t, uint64_t, uintptr_t, int32_t
from libc.stdint cimport uint8_t
ctypedef uint64_t hash_t
ctypedef char* utf8_t
ctypedef uint32_t attr_t
ctypedef int32_t attr_t
ctypedef uint64_t flags_t
ctypedef uint32_t id_t
ctypedef uint16_t len_t