* Update tests for huffman codec

This commit is contained in:
Matthew Honnibal 2015-07-19 17:59:51 +02:00
parent b8086067d5
commit 879ef9fa3e

View File

@ -6,35 +6,12 @@ import pytest
from spacy.serialize.huffman import HuffmanCodec from spacy.serialize.huffman import HuffmanCodec
from spacy.serialize.bits import BitArray from spacy.serialize.bits import BitArray
import numpy import numpy
import math
from heapq import heappush, heappop, heapify from heapq import heappush, heappop, heapify
from collections import defaultdict from collections import defaultdict
class MockPacker(object):
    """Test helper pairing a symbol table with a ``HuffmanCodec``.

    Symbols are ranked by descending frequency; the rank of a symbol is the
    integer index handed to the codec, and ``probs[rank]`` is that symbol's
    relative frequency.
    """

    def __init__(self, freqs):
        """Build the symbol table and the codec.

        freqs: mapping of symbol -> frequency (count or weight).
        """
        # float() keeps the division below a true division even on Python 2
        # without ``from __future__ import division``.
        total = float(sum(freqs.values()))
        # sorted() rather than ``items().sort()``: on Python 3 ``dict.items()``
        # is a view object with no ``sort`` method.  The sort stays stable, so
        # ties keep the dict's own iteration order as before.
        by_freq = sorted(freqs.items(), key=lambda item: item[1], reverse=True)
        self.symbols = [sym for sym, _ in by_freq]
        self.probs = numpy.array(
            [freq / total for _, freq in by_freq], dtype=numpy.float32)
        self.table = {sym: i for i, sym in enumerate(self.symbols)}
        self.codec = HuffmanCodec(self.probs)

    def pack(self, message):
        """Encode a sequence of symbols and return the resulting ``BitArray``.

        Raises ``KeyError`` if ``message`` contains a symbol that was not in
        the frequency table given to the constructor.
        """
        seq = [self.table[sym] for sym in message]
        msg = numpy.array(seq, dtype=numpy.int32)
        bits = BitArray()
        self.codec.encode(msg, bits)
        return bits

    def unpack(self, bits, n):
        """Decode ``n`` symbols from ``bits`` and return them as a list.

        ``bits`` is rewound with ``seek(0)`` before decoding.
        """
        # Same initial contents as numpy.array(range(n)); presumably
        # codec.decode overwrites this buffer in place -- confirm against
        # HuffmanCodec.decode.
        msg = numpy.arange(n, dtype=numpy.int32)
        bits.seek(0)
        self.codec.decode(bits, msg)
        return [self.symbols[i] for i in msg]
def py_encode(symb2freq): def py_encode(symb2freq):
"""Huffman encode the given dict mapping symbols to weights """Huffman encode the given dict mapping symbols to weights
From Rosetta Code From Rosetta Code
@ -65,7 +42,7 @@ def test1():
probs[8] = 0.0001 probs[8] = 0.0001
probs[9] = 0.000001 probs[9] = 0.000001
codec = HuffmanCodec(probs) codec = HuffmanCodec(list(enumerate(probs)))
py_codes = py_encode(dict(enumerate(probs))) py_codes = py_encode(dict(enumerate(probs)))
py_codes = py_codes.items() py_codes = py_codes.items()
@ -76,19 +53,21 @@ def test1():
def test_round_trip(): def test_round_trip():
freqs = {'the': 10, 'quick': 3, 'brown': 4, 'fox': 1, 'jumped': 5, 'over': 8, freqs = {'the': 10, 'quick': 3, 'brown': 4, 'fox': 1, 'jumped': 5, 'over': 8,
'lazy': 1, 'dog': 2, '.': 9} 'lazy': 1, 'dog': 2, '.': 9}
packer = MockPacker(freqs) codec = HuffmanCodec(freqs.items())
message = ['the', 'quick', 'brown', 'fox', 'jumped', 'over', 'the', message = ['the', 'quick', 'brown', 'fox', 'jumped', 'over', 'the',
'the', 'lazy', 'dog', '.'] 'the', 'lazy', 'dog', '.']
strings = list(packer.codec.strings) strings = list(codec.strings)
codes = {packer.symbols[i]: strings[i] for i in range(len(packer.symbols))} codes = {codec.leaves[i]: strings[i] for i in range(len(codec.leaves))}
bits = packer.pack(message) bits = codec.encode(message)
string = b''.join(b'{0:b}'.format(ord(c)).rjust(8, b'0')[::-1] for c in bits.as_bytes()) string = b''.join(b'{0:b}'.format(ord(c)).rjust(8, b'0')[::-1] for c in bits.as_bytes())
for word in message: for word in message:
code = codes[word] code = codes[word]
assert string[:len(code)] == code assert string[:len(code)] == code
string = string[len(code):] string = string[len(code):]
unpacked = packer.unpack(bits, len(message)) unpacked = [0] * len(message)
bits.seek(0)
codec.decode(bits, unpacked)
assert message == unpacked assert message == unpacked
@ -100,28 +79,29 @@ def test_rosetta():
by_freq = symb2freq.items() by_freq = symb2freq.items()
by_freq.sort(reverse=True, key=lambda item: item[1]) by_freq.sort(reverse=True, key=lambda item: item[1])
symbols = [sym for sym, prob in by_freq] symbols = [sym for sym, prob in by_freq]
probs = numpy.array([prob for sym, prob in by_freq], dtype=numpy.float32)
codec = HuffmanCodec(probs) codec = HuffmanCodec(symb2freq.items())
py_codec = py_encode(symb2freq) py_codec = py_encode(symb2freq)
codes = {codec.leaves[i]: codec.strings[i] for i in range(len(codec.leaves))}
my_lengths = defaultdict(int) my_lengths = defaultdict(int)
py_lengths = defaultdict(int) py_lengths = defaultdict(int)
for i, my in enumerate(codec.strings): for symb, freq in symb2freq.items():
symb = by_freq[i][0] my = codes[symb]
my_lengths[len(my)] += by_freq[i][1] my_lengths[len(my)] += freq
py_lengths[len(py_codec[symb])] += by_freq[i][1] py_lengths[len(py_codec[symb])] += freq
my_exp_len = sum(length * weight for length, weight in my_lengths.items()) my_exp_len = sum(length * weight for length, weight in my_lengths.items())
py_exp_len = sum(length * weight for length, weight in py_lengths.items()) py_exp_len = sum(length * weight for length, weight in py_lengths.items())
assert my_exp_len == py_exp_len assert my_exp_len == py_exp_len
"""
def test_vocab(EN): def test_vocab(EN):
codec = EN.vocab.codec codec = HuffmanCodec([(w.orth, numpy.exp(w.prob)) for w in EN.vocab])
expected_length = 0 expected_length = 0
for i, code in enumerate(codec.strings): for i, code in enumerate(codec.strings):
expected_length += len(code) * numpy.exp(EN.vocab[i].prob) leaf = codec.leaves[i]
expected_length += len(code) * numpy.exp(EN.vocab[leaf].prob)
assert 8 < expected_length < 15 assert 8 < expected_length < 15
@ -134,12 +114,10 @@ def test_freqs():
continue continue
freq, word = pieces freq, word = pieces
freqs.append(int(freq)) freqs.append(int(freq))
freqs.append(1) words.append(word)
total = sum(freqs) total = float(sum(freqs))
freqs = [(float(f) / total) for f in freqs] codec = HuffmanCodec(zip(words, freqs))
codec = HuffmanCodec(numpy.array(freqs, dtype=numpy.float32), len(freqs)-1)
expected_length = 0 expected_length = 0
for i, code in enumerate(codec.strings): for i, code in enumerate(codec.strings):
expected_length += len(code) * freqs[i] expected_length += len(code) * (freqs[i] / total)
assert 8 < expected_length < 14 assert 8 < expected_length < 14
"""