mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Modernise Huffman tests
This commit is contained in:
parent
edeeeccea5
commit
5dbc6e59f6
|
@ -1,15 +1,15 @@
|
||||||
|
# coding: utf-8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
|
||||||
import pytest
|
from ...serialize.huffman import HuffmanCodec
|
||||||
|
from ...serialize.bits import BitArray
|
||||||
|
|
||||||
from spacy.serialize.huffman import HuffmanCodec
|
|
||||||
from spacy.serialize.bits import BitArray
|
|
||||||
import numpy
|
|
||||||
import math
|
|
||||||
|
|
||||||
from heapq import heappush, heappop, heapify
|
from heapq import heappush, heappop, heapify
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
import numpy
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
def py_encode(symb2freq):
|
def py_encode(symb2freq):
|
||||||
|
@ -29,7 +29,7 @@ def py_encode(symb2freq):
|
||||||
return dict(heappop(heap)[1:])
|
return dict(heappop(heap)[1:])
|
||||||
|
|
||||||
|
|
||||||
def test1():
|
def test_serialize_huffman_1():
|
||||||
probs = numpy.zeros(shape=(10,), dtype=numpy.float32)
|
probs = numpy.zeros(shape=(10,), dtype=numpy.float32)
|
||||||
probs[0] = 0.3
|
probs[0] = 0.3
|
||||||
probs[1] = 0.2
|
probs[1] = 0.2
|
||||||
|
@ -43,43 +43,42 @@ def test1():
|
||||||
probs[9] = 0.000001
|
probs[9] = 0.000001
|
||||||
|
|
||||||
codec = HuffmanCodec(list(enumerate(probs)))
|
codec = HuffmanCodec(list(enumerate(probs)))
|
||||||
|
|
||||||
py_codes = py_encode(dict(enumerate(probs)))
|
py_codes = py_encode(dict(enumerate(probs)))
|
||||||
py_codes = list(py_codes.items())
|
py_codes = list(py_codes.items())
|
||||||
py_codes.sort()
|
py_codes.sort()
|
||||||
assert codec.strings == [c for i, c in py_codes]
|
assert codec.strings == [c for i, c in py_codes]
|
||||||
|
|
||||||
|
|
||||||
def test_empty():
|
def test_serialize_huffman_empty():
|
||||||
codec = HuffmanCodec({})
|
codec = HuffmanCodec({})
|
||||||
assert codec.strings == []
|
assert codec.strings == []
|
||||||
|
|
||||||
|
|
||||||
def test_round_trip():
|
def test_serialize_huffman_round_trip():
|
||||||
freqs = {'the': 10, 'quick': 3, 'brown': 4, 'fox': 1, 'jumped': 5, 'over': 8,
|
words = ['the', 'quick', 'brown', 'fox', 'jumped', 'over', 'the', 'the',
|
||||||
'lazy': 1, 'dog': 2, '.': 9}
|
'lazy', 'dog', '.']
|
||||||
codec = HuffmanCodec(freqs.items())
|
freqs = {'the': 10, 'quick': 3, 'brown': 4, 'fox': 1, 'jumped': 5,
|
||||||
|
'over': 8, 'lazy': 1, 'dog': 2, '.': 9}
|
||||||
|
|
||||||
message = ['the', 'quick', 'brown', 'fox', 'jumped', 'over', 'the',
|
codec = HuffmanCodec(freqs.items())
|
||||||
'the', 'lazy', 'dog', '.']
|
|
||||||
strings = list(codec.strings)
|
strings = list(codec.strings)
|
||||||
codes = dict([(codec.leaves[i], strings[i]) for i in range(len(codec.leaves))])
|
codes = dict([(codec.leaves[i], strings[i]) for i in range(len(codec.leaves))])
|
||||||
bits = codec.encode(message)
|
bits = codec.encode(words)
|
||||||
string = ''.join('{0:b}'.format(c).rjust(8, '0')[::-1] for c in bits.as_bytes())
|
string = ''.join('{0:b}'.format(c).rjust(8, '0')[::-1] for c in bits.as_bytes())
|
||||||
for word in message:
|
for word in words:
|
||||||
code = codes[word]
|
code = codes[word]
|
||||||
assert string[:len(code)] == code
|
assert string[:len(code)] == code
|
||||||
string = string[len(code):]
|
string = string[len(code):]
|
||||||
unpacked = [0] * len(message)
|
unpacked = [0] * len(words)
|
||||||
bits.seek(0)
|
bits.seek(0)
|
||||||
codec.decode(bits, unpacked)
|
codec.decode(bits, unpacked)
|
||||||
assert message == unpacked
|
assert words == unpacked
|
||||||
|
|
||||||
|
|
||||||
def test_rosetta():
|
def test_serialize_huffman_rosetta():
|
||||||
txt = u"this is an example for huffman encoding"
|
text = "this is an example for huffman encoding"
|
||||||
symb2freq = defaultdict(int)
|
symb2freq = defaultdict(int)
|
||||||
for ch in txt:
|
for ch in text:
|
||||||
symb2freq[ch] += 1
|
symb2freq[ch] += 1
|
||||||
by_freq = list(symb2freq.items())
|
by_freq = list(symb2freq.items())
|
||||||
by_freq.sort(reverse=True, key=lambda item: item[1])
|
by_freq.sort(reverse=True, key=lambda item: item[1])
|
||||||
|
@ -101,7 +100,7 @@ def test_rosetta():
|
||||||
assert my_exp_len == py_exp_len
|
assert my_exp_len == py_exp_len
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.models
|
||||||
def test_vocab(EN):
|
def test_vocab(EN):
|
||||||
codec = HuffmanCodec([(w.orth, numpy.exp(w.prob)) for w in EN.vocab])
|
codec = HuffmanCodec([(w.orth, numpy.exp(w.prob)) for w in EN.vocab])
|
||||||
expected_length = 0
|
expected_length = 0
|
||||||
|
|
Loading…
Reference in New Issue
Block a user