From cef547a9f05fc39ee667606b6561ea3f106a7018 Mon Sep 17 00:00:00 2001
From: ines
Date: Fri, 2 Jun 2017 18:18:30 +0200
Subject: [PATCH] Add serialization tests for tensorizer

---
Notes (review commentary only; text between the three-dash line and the
first "diff --git" header is not recorded in the commit):

  * test_serialize_tensorizer_roundtrip_bytes builds a Tensorizer on the
    en_vocab argument, assigns tensorizer.Model() to tensorizer.model,
    serializes it with to_bytes(), loads those bytes into a second
    Tensorizer via from_bytes(), and asserts that the second object's
    to_bytes() equals the original bytes.
  * test_serialize_tensorizer_roundtrip_disk does the same setup, writes the
    object to <tempdir>/tensorizer with to_disk(), reloads it into a second
    Tensorizer with from_disk(), and asserts that to_bytes() of the original
    and of the reloaded object are equal.
  * Tensorizer is only an import alias: the class under test is
    TokenVectorEncoder from the spacy pipeline module.
  * en_vocab is not defined in this file; presumably a pytest fixture supplied
    by the test suite's conftest -- TODO confirm.
  * "import pytest" is not referenced by either test in this patch.
  * NOTE(review): the line structure of this patch was restored from a
    whitespace-collapsed copy. Python indentation was reconstructed from
    syntax; the final assert of the disk test is assumed to sit inside the
    "with make_tempdir()" block -- confirm against the original commit.

 .../serialize/test_serialize_tensorizer.py | 25 +++++++++++++++++++
 1 file changed, 25 insertions(+)
 create mode 100644 spacy/tests/serialize/test_serialize_tensorizer.py

diff --git a/spacy/tests/serialize/test_serialize_tensorizer.py b/spacy/tests/serialize/test_serialize_tensorizer.py
new file mode 100644
index 000000000..ba01a2fa6
--- /dev/null
+++ b/spacy/tests/serialize/test_serialize_tensorizer.py
@@ -0,0 +1,25 @@
+# coding: utf-8
+from __future__ import unicode_literals
+
+from ..util import make_tempdir
+from ...pipeline import TokenVectorEncoder as Tensorizer
+
+import pytest
+
+
+def test_serialize_tensorizer_roundtrip_bytes(en_vocab):
+    tensorizer = Tensorizer(en_vocab)
+    tensorizer.model = tensorizer.Model()
+    tensorizer_b = tensorizer.to_bytes()
+    new_tensorizer = Tensorizer(en_vocab).from_bytes(tensorizer_b)
+    assert new_tensorizer.to_bytes() == tensorizer_b
+
+
+def test_serialize_tensorizer_roundtrip_disk(en_vocab):
+    tensorizer = Tensorizer(en_vocab)
+    tensorizer.model = tensorizer.Model()
+    with make_tempdir() as d:
+        file_path = d / 'tensorizer'
+        tensorizer.to_disk(file_path)
+        tensorizer_d = Tensorizer(en_vocab).from_disk(file_path)
+        assert tensorizer.to_bytes() == tensorizer_d.to_bytes()