Add serialization tests for tagger

This commit is contained in:
ines 2017-06-02 17:29:34 +02:00
parent 1b593bbd6d
commit 43b4d63f85

View File

@ -0,0 +1,39 @@
# coding: utf-8
from __future__ import unicode_literals
from ..util import make_tempdir
from ...pipeline import NeuralTagger as Tagger
import pytest
@pytest.fixture
def taggers(en_vocab):
tagger1 = Tagger(en_vocab, True)
tagger2 = Tagger(en_vocab, True)
tagger1.model = tagger1.Model(None, None)
tagger2.model = tagger2.Model(None, None)
return (tagger1, tagger2)
def test_serialize_tagger_roundtrip_bytes(en_vocab, taggers):
tagger1, tagger2 = taggers
tagger1_b = tagger1.to_bytes()
tagger2_b = tagger2.to_bytes()
assert tagger1_b == tagger2_b
tagger1 = tagger1.from_bytes(tagger1_b)
assert tagger1.to_bytes() == tagger1_b
new_tagger1 = Tagger(en_vocab).from_bytes(tagger1_b)
assert new_tagger1.to_bytes() == tagger1_b
def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
tagger1, tagger2 = taggers
with make_tempdir() as d:
file_path1 = d / 'tagger1'
file_path2 = d / 'tagger2'
tagger1.to_disk(file_path1)
tagger2.to_disk(file_path2)
tagger1_d = Tagger(en_vocab).from_disk(file_path1)
tagger2_d = Tagger(en_vocab).from_disk(file_path2)
assert tagger1_d.to_bytes() == tagger2_d.to_bytes()