* Tests passing on round-trip pack/unpack on basic example

This commit is contained in:
Matthew Honnibal 2015-07-17 21:20:48 +02:00
parent 44f39a876f
commit cf0c788892
7 changed files with 108 additions and 51 deletions

View File

@ -19,3 +19,5 @@ cdef class BitArray:
cdef uint32_t i cdef uint32_t i
cdef int extend(self, uint64_t code, char n_bits) except -1 cdef int extend(self, uint64_t code, char n_bits) except -1
cdef uint32_t read32(self) except 0

View File

@ -1,4 +1,4 @@
from libc.string cimport memcpy
# Note that we're setting the most significant bits here first, when in practice # Note that we're setting the most significant bits here first, when in practice
# we're actually wanting the last bit to be most significant (for Huffman coding, # we're actually wanting the last bit to be most significant (for Huffman coding,
@ -20,20 +20,63 @@ cdef class BitArray:
self.bit_of_byte = 0 self.bit_of_byte = 0
self.i = 0 self.i = 0
def __len__(self):
return 8 * len(self.data) + self.bit_of_byte
def __str__(self):
cdef uchar byte, i
cdef uchar one = 1
string = b''
for i in range(len(self.data)):
byte = ord(self.data[i])
for j in range(8):
string += b'1' if (byte & (one << j)) else b'0'
for i in range(self.bit_of_byte):
string += b'1' if (byte & (one << i)) else b'0'
return string
def seek(self, i):
self.i = i
def __iter__(self): def __iter__(self):
cdef uchar byte, i cdef uchar byte, i
cdef uchar one = 1 cdef uchar one = 1
start_byte = self.i // 8 start_byte = self.i // 8
if (self.i % 8) != 0: start_bit = self.i % 8
for i in range(self.i % 8):
yield 1 if (self.data[start_byte] & (one << i)) else 0 if start_bit != 0 and start_byte < len(self.data):
byte = ord(self.data[start_byte])
for i in range(start_bit, 8):
self.i += 1
yield 1 if (byte & (one << i)) else 0
start_byte += 1 start_byte += 1
start_bit = 0
for byte in self.data[start_byte:]: for byte in self.data[start_byte:]:
for i in range(8): for i in range(8):
self.i += 1
yield 1 if byte & (one << i) else 0 yield 1 if byte & (one << i) else 0
for i in range(self.bit_of_byte):
if self.bit_of_byte != 0:
byte = self.byte
for i in range(start_bit, self.bit_of_byte):
self.i += 1
yield 1 if self.byte & (one << i) else 0 yield 1 if self.byte & (one << i) else 0
cdef uint32_t read32(self) except 0:
cdef int start_byte = self.i // 8
# TODO portability
cdef uchar[4] chars
chars[0] = <uchar>ord(self.data[start_byte])
chars[1] = <uchar>ord(self.data[start_byte+1])
chars[2] = <uchar>ord(self.data[start_byte+2])
chars[3] = <uchar>ord(self.data[start_byte+3])
cdef uint32_t output
memcpy(&output, chars, 4)
self.i += 32
return output
def as_bytes(self): def as_bytes(self):
if self.bit_of_byte != 0: if self.bit_of_byte != 0:
return self.data + chr(self.byte) return self.data + chr(self.byte)
@ -47,6 +90,7 @@ cdef class BitArray:
else: else:
self.byte &= ~(one << self.bit_of_byte) self.byte &= ~(one << self.bit_of_byte)
self.bit_of_byte += 1 self.bit_of_byte += 1
self.i += 1
if self.bit_of_byte == 8: if self.bit_of_byte == 8:
self.data += chr(self.byte) self.data += chr(self.byte)
self.byte = 0 self.byte = 0
@ -65,5 +109,4 @@ cdef class BitArray:
self.data += chr(self.byte) self.data += chr(self.byte)
self.byte = 0 self.byte = 0
self.bit_of_byte = 0 self.bit_of_byte = 0
self.i += 1

View File

@ -1,5 +1,7 @@
cimport cython cimport cython
from ..typedefs cimport attr_t
from .bits cimport bit_append from .bits cimport bit_append
from .bits cimport BitArray from .bits cimport BitArray
@ -16,7 +18,6 @@ cdef class HuffmanCodec:
weights (float[:]): A descending-sorted sequence of probabilities/weights. weights (float[:]): A descending-sorted sequence of probabilities/weights.
Must include a weight for an EOL symbol. Must include a weight for an EOL symbol.
eol (uint32_t): The index of the weight of the EOL symbol.
""" """
def __init__(self, float[:] weights): def __init__(self, float[:] weights):
self.codes.resize(len(weights)) self.codes.resize(len(weights))
@ -29,12 +30,12 @@ cdef class HuffmanCodec:
path.length = 0 path.length = 0
assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path) assign_codes(self.nodes, self.codes, len(self.nodes) - 1, path)
def encode(self, uint32_t[:] msg, BitArray into_bits): def encode(self, attr_t[:] msg, BitArray into_bits):
cdef uint32_t i cdef int i
for i in range(len(msg)): for i in range(len(msg)):
into_bits.extend(self.codes[msg[i]].bits, self.codes[msg[i]].length) into_bits.extend(self.codes[msg[i]].bits, self.codes[msg[i]].length)
def decode(self, bits, uint32_t[:] into_msg): def decode(self, bits, attr_t[:] into_msg):
node = self.nodes.back() node = self.nodes.back()
cdef int i = 0 cdef int i = 0
cdef int n = len(into_msg) cdef int n = len(into_msg)
@ -49,7 +50,8 @@ cdef class HuffmanCodec:
if i == n: if i == n:
break break
else: else:
raise Exception raise Exception(
"Buffer exhausted at %d/%d symbols read." % (i, len(into_msg)))
property strings: property strings:
@cython.boundscheck(False) @cython.boundscheck(False)

View File

@ -2,5 +2,6 @@ from ..vocab cimport Vocab
cdef class Packer: cdef class Packer:
cdef tuple _codecs cdef readonly tuple attrs
cdef Vocab vocab cdef readonly tuple _codecs
cdef readonly Vocab vocab

View File

@ -7,7 +7,7 @@ from libcpp.pair cimport pair
from cymem.cymem cimport Address, Pool from cymem.cymem cimport Address, Pool
from preshed.maps cimport PreshMap from preshed.maps cimport PreshMap
from ..attrs cimport ID, ORTH, SPACY, TAG, HEAD, DEP, ENT_IOB, ENT_TYPE from ..attrs cimport ID, SPACY, TAG, HEAD, DEP, ENT_IOB, ENT_TYPE
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
from ..vocab cimport Vocab from ..vocab cimport Vocab
from ..typedefs cimport attr_t from ..typedefs cimport attr_t
@ -50,24 +50,28 @@ cdef class _BinaryCodec:
for i in range(len(msg)): for i in range(len(msg)):
bits.append(msg[i]) bits.append(msg[i])
def decode(self, bits, attr_t[:] msg): def decode(self, BitArray bits, attr_t[:] msg):
for i in range(len(msg)): cdef int i = 0
msg[i] = bits.next() for bit in bits:
msg[i] = bit
i += 1
if i == len(msg):
break
cdef class _AttributeCodec: cdef class _AttributeCodec:
cdef Pool mem cdef Pool mem
cdef attr_t* _keys cdef attr_t* _keys
cdef PreshMap _map cdef dict _map
cdef HuffmanCodec _codec cdef HuffmanCodec _codec
def __init__(self, freqs): def __init__(self, freqs):
self.mem = Pool() self.mem = Pool()
cdef uint64_t key cdef attr_t key
cdef uint64_t count cdef float count
cdef pair[uint64_t, uint64_t] item cdef pair[float, attr_t] item
cdef priority_queue[pair[uint64_t, uint64_t]] items cdef priority_queue[pair[float, attr_t]] items
for key, count in freqs: for key, count in freqs:
item.first = count item.first = count
@ -75,7 +79,7 @@ cdef class _AttributeCodec:
items.push(item) items.push(item)
weights = numpy.ndarray(shape=(len(freqs),), dtype=numpy.float32) weights = numpy.ndarray(shape=(len(freqs),), dtype=numpy.float32)
self._keys = <attr_t*>self.mem.alloc(len(freqs), sizeof(attr_t)) self._keys = <attr_t*>self.mem.alloc(len(freqs), sizeof(attr_t))
self._map = PreshMap() self._map = {}
cdef int i = 0 cdef int i = 0
while not items.empty(): while not items.empty():
item = items.top() item = items.top()
@ -88,8 +92,9 @@ cdef class _AttributeCodec:
self._codec = HuffmanCodec(weights) self._codec = HuffmanCodec(weights)
def encode(self, attr_t[:] msg, BitArray dest): def encode(self, attr_t[:] msg, BitArray dest):
cdef int i
for i in range(len(msg)): for i in range(len(msg)):
msg[i] = <attr_t>self._map[msg[i]] msg[i] = self._map[msg[i]]
self._codec.encode(msg, dest) self._codec.encode(msg, dest)
def decode(self, BitArray bits, attr_t[:] dest): def decode(self, BitArray bits, attr_t[:] dest):
@ -103,30 +108,36 @@ cdef class Packer:
def __init__(self, Vocab vocab, list_of_attr_freqs): def __init__(self, Vocab vocab, list_of_attr_freqs):
self.vocab = vocab self.vocab = vocab
codecs = [] codecs = []
self.attrs = [] attrs = []
for attr, freqs in list_of_attr_freqs: for attr, freqs in list_of_attr_freqs:
if attr == ORTH: if attr == ID:
codecs.append(make_vocab_codec(vocab)) codecs.append(make_vocab_codec(vocab))
elif attr == SPACY: elif attr == SPACY:
codecs.append(_BinaryCodec()) codecs.append(_BinaryCodec())
else: else:
codecs.append(_AttributeCodec(freqs)) codecs.append(_AttributeCodec(freqs))
self.attrs.append(attr) attrs.append(attr)
self._codecs = tuple(codecs) self._codecs = tuple(codecs)
self.attrs = tuple(attrs)
def pack(self, Doc doc): def pack(self, Doc doc):
array = doc.to_array(self.attrs) array = doc.to_array(self.attrs)
cdef BitArray bits = BitArray() cdef BitArray bits = BitArray()
cdef uint32_t length = len(array) cdef uint32_t length = 3
bits.extend(length, 32) #cdef uint32_t length = len(doc)
#bits.extend(length, 32)
for i, codec in enumerate(self._codecs): for i, codec in enumerate(self._codecs):
codec.encode(array[i], bits) codec.encode(array[:, i], bits)
return bits return bits
def unpack(self, bits): def unpack(self, BitArray bits):
cdef uint32_t length = bits.read(32) bits.seek(0)
array = numpy.ndarray(shape=(len(self.codecs), length), dtype=numpy.int) #cdef uint32_t length = bits.read32()
for i, codec in enumerate(self.codecs): cdef uint32_t length = 3
array[i] = codec.decode(bits) array = numpy.zeros(shape=(length, len(self._codecs)), dtype=numpy.int32)
return Doc.from_array(self.vocab, self.attrs, array) for i, codec in enumerate(self._codecs):
codec.decode(bits, array[:, i])
doc = Doc.from_ids(self.vocab, array[:, 0], array[:, 1])
doc.from_array(self.attrs, array)
return doc

View File

@ -97,17 +97,15 @@ cdef class Doc:
self._py_tokens = [] self._py_tokens = []
@classmethod @classmethod
def from_orth(cls, Vocab vocab, attr_t[:] orths, attr_t[:] spaces): def from_ids(cls, Vocab vocab, ids, spaces):
cdef int i cdef int i
cdef const LexemeC* lex cdef const LexemeC* lex
cdef Doc self = cls(vocab) cdef Doc self = cls(vocab)
cdef unicode string cdef bint space = 0
cdef UniStr new_orth_c for i in range(len(ids)):
for i in range(len(orths)): lex = self.vocab.lexemes.at(ids[i])
string = vocab.strings[orths[i]] space = spaces[i]
slice_unicode(&new_orth_c, string, 0, len(string)) self.push_back(lex, space)
lex = self.vocab.get(self.mem, &new_orth_c)
self.push_back(lex, spaces[i])
return self return self
def __getitem__(self, object i): def __getitem__(self, object i):
@ -229,11 +227,11 @@ cdef class Doc:
""" """
cdef int i, j cdef int i, j
cdef attr_id_t feature cdef attr_id_t feature
cdef np.ndarray[long, ndim=2] output cdef np.ndarray[attr_t, ndim=2] output
# Make an array from the attributes --- otherwise our inner loop is Python # Make an array from the attributes --- otherwise our inner loop is Python
# dict iteration. # dict iteration.
cdef np.ndarray[long, ndim=1] attr_ids = numpy.asarray(py_attr_ids) cdef np.ndarray[attr_t, ndim=1] attr_ids = numpy.asarray(py_attr_ids, dtype=numpy.int32)
output = numpy.ndarray(shape=(self.length, len(attr_ids)), dtype=numpy.int) output = numpy.ndarray(shape=(self.length, len(attr_ids)), dtype=numpy.int32)
for i in range(self.length): for i in range(self.length):
for j, feature in enumerate(attr_ids): for j, feature in enumerate(attr_ids):
output[i, j] = get_token_attr(&self.data[i], feature) output[i, j] = get_token_attr(&self.data[i], feature)

View File

@ -1,10 +1,10 @@
from libc.stdint cimport uint16_t, uint32_t, uint64_t, uintptr_t from libc.stdint cimport uint16_t, uint32_t, uint64_t, uintptr_t, int32_t
from libc.stdint cimport uint8_t from libc.stdint cimport uint8_t
ctypedef uint64_t hash_t ctypedef uint64_t hash_t
ctypedef char* utf8_t ctypedef char* utf8_t
ctypedef uint32_t attr_t ctypedef int32_t attr_t
ctypedef uint64_t flags_t ctypedef uint64_t flags_t
ctypedef uint32_t id_t ctypedef uint32_t id_t
ctypedef uint16_t len_t ctypedef uint16_t len_t