mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-25 00:34:20 +03:00
Add serialization tests for tagger
This commit is contained in:
parent
1b593bbd6d
commit
43b4d63f85
39
spacy/tests/serialize/test_serialize_tagger.py
Normal file
39
spacy/tests/serialize/test_serialize_tagger.py
Normal file
|
@ -0,0 +1,39 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from ..util import make_tempdir
|
||||
from ...pipeline import NeuralTagger as Tagger
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def taggers(en_vocab):
|
||||
tagger1 = Tagger(en_vocab, True)
|
||||
tagger2 = Tagger(en_vocab, True)
|
||||
tagger1.model = tagger1.Model(None, None)
|
||||
tagger2.model = tagger2.Model(None, None)
|
||||
return (tagger1, tagger2)
|
||||
|
||||
|
||||
def test_serialize_tagger_roundtrip_bytes(en_vocab, taggers):
|
||||
tagger1, tagger2 = taggers
|
||||
tagger1_b = tagger1.to_bytes()
|
||||
tagger2_b = tagger2.to_bytes()
|
||||
assert tagger1_b == tagger2_b
|
||||
tagger1 = tagger1.from_bytes(tagger1_b)
|
||||
assert tagger1.to_bytes() == tagger1_b
|
||||
new_tagger1 = Tagger(en_vocab).from_bytes(tagger1_b)
|
||||
assert new_tagger1.to_bytes() == tagger1_b
|
||||
|
||||
|
||||
def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
|
||||
tagger1, tagger2 = taggers
|
||||
with make_tempdir() as d:
|
||||
file_path1 = d / 'tagger1'
|
||||
file_path2 = d / 'tagger2'
|
||||
tagger1.to_disk(file_path1)
|
||||
tagger2.to_disk(file_path2)
|
||||
tagger1_d = Tagger(en_vocab).from_disk(file_path1)
|
||||
tagger2_d = Tagger(en_vocab).from_disk(file_path2)
|
||||
assert tagger1_d.to_bytes() == tagger2_d.to_bytes()
|
Loading…
Reference in New Issue
Block a user