Modernise Huffman tests

This commit is contained in:
Ines Montani 2017-01-12 21:58:40 +01:00
parent edeeeccea5
commit 5dbc6e59f6

View File

@@ -1,15 +1,15 @@
# coding: utf-8
from __future__ import unicode_literals from __future__ import unicode_literals
from __future__ import division from __future__ import division
import pytest from ...serialize.huffman import HuffmanCodec
from ...serialize.bits import BitArray
from spacy.serialize.huffman import HuffmanCodec
from spacy.serialize.bits import BitArray
import numpy
import math
from heapq import heappush, heappop, heapify from heapq import heappush, heappop, heapify
from collections import defaultdict from collections import defaultdict
import numpy
import pytest
def py_encode(symb2freq): def py_encode(symb2freq):
@@ -29,7 +29,7 @@ def py_encode(symb2freq):
return dict(heappop(heap)[1:]) return dict(heappop(heap)[1:])
def test1(): def test_serialize_huffman_1():
probs = numpy.zeros(shape=(10,), dtype=numpy.float32) probs = numpy.zeros(shape=(10,), dtype=numpy.float32)
probs[0] = 0.3 probs[0] = 0.3
probs[1] = 0.2 probs[1] = 0.2
@@ -43,43 +43,42 @@ def test1():
probs[9] = 0.000001 probs[9] = 0.000001
codec = HuffmanCodec(list(enumerate(probs))) codec = HuffmanCodec(list(enumerate(probs)))
py_codes = py_encode(dict(enumerate(probs))) py_codes = py_encode(dict(enumerate(probs)))
py_codes = list(py_codes.items()) py_codes = list(py_codes.items())
py_codes.sort() py_codes.sort()
assert codec.strings == [c for i, c in py_codes] assert codec.strings == [c for i, c in py_codes]
def test_empty(): def test_serialize_huffman_empty():
codec = HuffmanCodec({}) codec = HuffmanCodec({})
assert codec.strings == [] assert codec.strings == []
def test_round_trip(): def test_serialize_huffman_round_trip():
freqs = {'the': 10, 'quick': 3, 'brown': 4, 'fox': 1, 'jumped': 5, 'over': 8, words = ['the', 'quick', 'brown', 'fox', 'jumped', 'over', 'the', 'the',
'lazy': 1, 'dog': 2, '.': 9} 'lazy', 'dog', '.']
codec = HuffmanCodec(freqs.items()) freqs = {'the': 10, 'quick': 3, 'brown': 4, 'fox': 1, 'jumped': 5,
'over': 8, 'lazy': 1, 'dog': 2, '.': 9}
message = ['the', 'quick', 'brown', 'fox', 'jumped', 'over', 'the', codec = HuffmanCodec(freqs.items())
'the', 'lazy', 'dog', '.']
strings = list(codec.strings) strings = list(codec.strings)
codes = dict([(codec.leaves[i], strings[i]) for i in range(len(codec.leaves))]) codes = dict([(codec.leaves[i], strings[i]) for i in range(len(codec.leaves))])
bits = codec.encode(message) bits = codec.encode(words)
string = ''.join('{0:b}'.format(c).rjust(8, '0')[::-1] for c in bits.as_bytes()) string = ''.join('{0:b}'.format(c).rjust(8, '0')[::-1] for c in bits.as_bytes())
for word in message: for word in words:
code = codes[word] code = codes[word]
assert string[:len(code)] == code assert string[:len(code)] == code
string = string[len(code):] string = string[len(code):]
unpacked = [0] * len(message) unpacked = [0] * len(words)
bits.seek(0) bits.seek(0)
codec.decode(bits, unpacked) codec.decode(bits, unpacked)
assert message == unpacked assert words == unpacked
def test_rosetta(): def test_serialize_huffman_rosetta():
txt = u"this is an example for huffman encoding" text = "this is an example for huffman encoding"
symb2freq = defaultdict(int) symb2freq = defaultdict(int)
for ch in txt: for ch in text:
symb2freq[ch] += 1 symb2freq[ch] += 1
by_freq = list(symb2freq.items()) by_freq = list(symb2freq.items())
by_freq.sort(reverse=True, key=lambda item: item[1]) by_freq.sort(reverse=True, key=lambda item: item[1])
@@ -101,7 +100,7 @@ def test_rosetta():
assert my_exp_len == py_exp_len assert my_exp_len == py_exp_len
@pytest.mark.slow @pytest.mark.models
def test_vocab(EN): def test_vocab(EN):
codec = HuffmanCodec([(w.orth, numpy.exp(w.prob)) for w in EN.vocab]) codec = HuffmanCodec([(w.orth, numpy.exp(w.prob)) for w in EN.vocab])
expected_length = 0 expected_length = 0