From 1c218397f69beeb895f357662d718d627842a251 Mon Sep 17 00:00:00 2001 From: ines Date: Thu, 9 Nov 2017 02:29:03 +0100 Subject: [PATCH] Ensure path in Doc.to_disk/from_disk (resolves #1521) Also add Doc serialization tests with both Path and string path options --- spacy/tests/serialize/test_serialize_doc.py | 34 +++++++++++++++++++++ spacy/tokens/doc.pyx | 2 ++ 2 files changed, 36 insertions(+) create mode 100644 spacy/tests/serialize/test_serialize_doc.py diff --git a/spacy/tests/serialize/test_serialize_doc.py b/spacy/tests/serialize/test_serialize_doc.py new file mode 100644 index 000000000..5a10b656a --- /dev/null +++ b/spacy/tests/serialize/test_serialize_doc.py @@ -0,0 +1,34 @@ +# coding: utf-8 +from __future__ import unicode_literals + +from ..util import make_tempdir, get_doc +from ...tokens import Doc +from ...compat import path2str + +import pytest + + +def test_serialize_doc_roundtrip_bytes(en_vocab): + doc = get_doc(en_vocab, words=['hello', 'world']) + doc_b = doc.to_bytes() + new_doc = Doc(en_vocab).from_bytes(doc_b) + assert new_doc.to_bytes() == doc_b + + +def test_serialize_doc_roundtrip_disk(en_vocab): + doc = get_doc(en_vocab, words=['hello', 'world']) + with make_tempdir() as d: + file_path = d / 'doc' + doc.to_disk(file_path) + doc_d = Doc(en_vocab).from_disk(file_path) + assert doc.to_bytes() == doc_d.to_bytes() + + +def test_serialize_doc_roundtrip_disk_str_path(en_vocab): + doc = get_doc(en_vocab, words=['hello', 'world']) + with make_tempdir() as d: + file_path = d / 'doc' + file_path = path2str(file_path) + doc.to_disk(file_path) + doc_d = Doc(en_vocab).from_disk(file_path) + assert doc.to_bytes() == doc_d.to_bytes() diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index eef25c712..68617bb5e 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -737,6 +737,7 @@ cdef class Doc: path (unicode or Path): A path to a directory, which will be created if it doesn't exist. Paths may be either strings or Path-like objects. 
""" + path = util.ensure_path(path) with path.open('wb') as file_: file_.write(self.to_bytes(**exclude)) @@ -748,6 +749,7 @@ cdef class Doc: strings or `Path`-like objects. RETURNS (Doc): The modified `Doc` object. """ + path = util.ensure_path(path) with path.open('rb') as file_: bytes_data = file_.read() return self.from_bytes(bytes_data, **exclude)