* Fix tests

This commit is contained in:
Matthew Honnibal 2015-07-17 21:31:44 +02:00
parent 68374149ae
commit f7f0ad1a78
2 changed files with 7 additions and 4 deletions

View File

@ -16,7 +16,8 @@ def test_binary():
msg = numpy.array([0, 1, 0, 1, 1], numpy.int32)
codec.encode(msg, bits)
result = numpy.array([0, 0, 0, 0, 0], numpy.int32)
codec.decode(iter(bits), result)
bits.seek(0)
codec.decode(bits, result)
assert list(msg) == list(result)
@ -35,6 +36,7 @@ def test_attribute():
msg_list = list(msg)
codec.encode(msg, bits)
result = numpy.array([0, 0], dtype=numpy.int32)
bits.seek(0)
codec.decode(bits, result)
assert msg_list == list(result)
@ -69,5 +71,6 @@ def test_vocab_codec():
msg_list = list(msg)
codec.encode(msg, bits)
result = numpy.array(range(len(msg)), dtype=numpy.int32)
bits.seek(0)
codec.decode(bits, result)
assert msg_list == list(result)

View File

@ -13,7 +13,6 @@ from collections import defaultdict
class MockPacker(object):
def __init__(self, freqs):
freqs['-eol-'] = 5
total = sum(freqs.values())
by_freq = freqs.items()
by_freq.sort(key=lambda item: item[1], reverse=True)
@ -24,13 +23,14 @@ class MockPacker(object):
def pack(self, message):
seq = [self.table[sym] for sym in message]
msg = numpy.array(seq, dtype=numpy.uint32)
msg = numpy.array(seq, dtype=numpy.int32)
bits = BitArray()
self.codec.encode(msg, bits)
return bits
def unpack(self, bits, n):
msg = numpy.array(range(n), dtype=numpy.uint32)
msg = numpy.array(range(n), dtype=numpy.int32)
bits.seek(0)
self.codec.decode(bits, msg)
return [self.symbols[i] for i in msg]