mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-24 00:04:15 +03:00
* Fix tests
This commit is contained in:
parent
68374149ae
commit
f7f0ad1a78
|
@ -16,7 +16,8 @@ def test_binary():
|
||||||
msg = numpy.array([0, 1, 0, 1, 1], numpy.int32)
|
msg = numpy.array([0, 1, 0, 1, 1], numpy.int32)
|
||||||
codec.encode(msg, bits)
|
codec.encode(msg, bits)
|
||||||
result = numpy.array([0, 0, 0, 0, 0], numpy.int32)
|
result = numpy.array([0, 0, 0, 0, 0], numpy.int32)
|
||||||
codec.decode(iter(bits), result)
|
bits.seek(0)
|
||||||
|
codec.decode(bits, result)
|
||||||
assert list(msg) == list(result)
|
assert list(msg) == list(result)
|
||||||
|
|
||||||
|
|
||||||
|
@ -35,6 +36,7 @@ def test_attribute():
|
||||||
msg_list = list(msg)
|
msg_list = list(msg)
|
||||||
codec.encode(msg, bits)
|
codec.encode(msg, bits)
|
||||||
result = numpy.array([0, 0], dtype=numpy.int32)
|
result = numpy.array([0, 0], dtype=numpy.int32)
|
||||||
|
bits.seek(0)
|
||||||
codec.decode(bits, result)
|
codec.decode(bits, result)
|
||||||
assert msg_list == list(result)
|
assert msg_list == list(result)
|
||||||
|
|
||||||
|
@ -69,5 +71,6 @@ def test_vocab_codec():
|
||||||
msg_list = list(msg)
|
msg_list = list(msg)
|
||||||
codec.encode(msg, bits)
|
codec.encode(msg, bits)
|
||||||
result = numpy.array(range(len(msg)), dtype=numpy.int32)
|
result = numpy.array(range(len(msg)), dtype=numpy.int32)
|
||||||
|
bits.seek(0)
|
||||||
codec.decode(bits, result)
|
codec.decode(bits, result)
|
||||||
assert msg_list == list(result)
|
assert msg_list == list(result)
|
||||||
|
|
|
@ -13,7 +13,6 @@ from collections import defaultdict
|
||||||
|
|
||||||
class MockPacker(object):
|
class MockPacker(object):
|
||||||
def __init__(self, freqs):
|
def __init__(self, freqs):
|
||||||
freqs['-eol-'] = 5
|
|
||||||
total = sum(freqs.values())
|
total = sum(freqs.values())
|
||||||
by_freq = freqs.items()
|
by_freq = freqs.items()
|
||||||
by_freq.sort(key=lambda item: item[1], reverse=True)
|
by_freq.sort(key=lambda item: item[1], reverse=True)
|
||||||
|
@ -24,13 +23,14 @@ class MockPacker(object):
|
||||||
|
|
||||||
def pack(self, message):
|
def pack(self, message):
|
||||||
seq = [self.table[sym] for sym in message]
|
seq = [self.table[sym] for sym in message]
|
||||||
msg = numpy.array(seq, dtype=numpy.uint32)
|
msg = numpy.array(seq, dtype=numpy.int32)
|
||||||
bits = BitArray()
|
bits = BitArray()
|
||||||
self.codec.encode(msg, bits)
|
self.codec.encode(msg, bits)
|
||||||
return bits
|
return bits
|
||||||
|
|
||||||
def unpack(self, bits, n):
|
def unpack(self, bits, n):
|
||||||
msg = numpy.array(range(n), dtype=numpy.uint32)
|
msg = numpy.array(range(n), dtype=numpy.int32)
|
||||||
|
bits.seek(0)
|
||||||
self.codec.decode(bits, msg)
|
self.codec.decode(bits, msg)
|
||||||
return [self.symbols[i] for i in msg]
|
return [self.symbols[i] for i in msg]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user