Ensure files opened in from_disk are closed

Fixes [issue 1706](https://github.com/explosion/spaCy/issues/1706).
This commit is contained in:
Claudiu-Vlad Ursache 2018-02-13 20:44:33 +01:00
parent cdd4b3d05c
commit e28de12cbd
No known key found for this signature in database
GPG Key ID: 9A3505F5EA386896
5 changed files with 36 additions and 5 deletions

View File

@ -624,7 +624,7 @@ class Language(object):
deserializers = OrderedDict(( deserializers = OrderedDict((
('vocab', lambda p: self.vocab.from_disk(p)), ('vocab', lambda p: self.vocab.from_disk(p)),
('tokenizer', lambda p: self.tokenizer.from_disk(p, vocab=False)), ('tokenizer', lambda p: self.tokenizer.from_disk(p, vocab=False)),
('meta.json', lambda p: self.meta.update(ujson.load(p.open('r')))) ('meta.json', lambda p: self.meta.update(util.read_json(p)))
)) ))
for name, proc in self.pipeline: for name, proc in self.pipeline:
if name in disable: if name in disable:

View File

@ -214,7 +214,8 @@ class Pipe(object):
def _load_cfg(path): def _load_cfg(path):
if path.exists(): if path.exists():
return ujson.load(path.open()) with path.open() as file_:
return ujson.load(file_)
else: else:
return {} return {}
@ -580,7 +581,8 @@ class Tagger(Pipe):
def load_model(p): def load_model(p):
if self.model is True: if self.model is True:
self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg) self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
self.model.from_bytes(p.open('rb').read()) with p.open('rb') as file_:
self.model.from_bytes(file_.read())
def load_tag_map(p): def load_tag_map(p):
with p.open('rb') as file_: with p.open('rb') as file_:

View File

@ -887,7 +887,7 @@ cdef class Parser:
deserializers = { deserializers = {
'vocab': lambda p: self.vocab.from_disk(p), 'vocab': lambda p: self.vocab.from_disk(p),
'moves': lambda p: self.moves.from_disk(p, strings=False), 'moves': lambda p: self.moves.from_disk(p, strings=False),
'cfg': lambda p: self.cfg.update(ujson.load(p.open())), 'cfg': lambda p: self.cfg.update(util.read_json(p)),
'model': lambda p: None 'model': lambda p: None
} }
util.from_disk(path, deserializers, exclude) util.from_disk(path, deserializers, exclude)

View File

@ -0,0 +1,28 @@
# coding: utf-8
from __future__ import unicode_literals
from ..util import make_tempdir
from ...language import Language
import pytest
@pytest.fixture
def meta_data():
    """Return a dict of placeholder metadata for a ``Language`` object.

    Each value is the field name followed by ``-in-fixture``.
    """
    fixture_meta = {
        'name': 'name-in-fixture',
        'version': 'version-in-fixture',
        'description': 'description-in-fixture',
        'author': 'author-in-fixture',
        'email': 'email-in-fixture',
        'url': 'url-in-fixture',
        'license': 'license-in-fixture',
    }
    return fixture_meta
def test_serialize_language_meta_disk(meta_data):
    """Check that ``Language.meta`` is unchanged after a to_disk/from_disk round trip.

    meta_data: the dict supplied by the ``meta_data`` fixture, passed as
        ``meta`` when constructing the ``Language`` that gets saved.
    """
    original = Language(meta=meta_data)
    with make_tempdir() as tmp_dir:
        original.to_disk(tmp_dir)
        # from_disk's return value is what gets compared, as in the
        # original test.
        restored = Language().from_disk(tmp_dir)
    assert restored.meta == original.meta

View File

@ -347,7 +347,8 @@ cdef class Vectors:
""" """
def load_key2row(path): def load_key2row(path):
if path.exists(): if path.exists():
self.key2row = msgpack.load(path.open('rb')) with path.open('rb') as file_:
self.key2row = msgpack.load(file_)
for key, row in self.key2row.items(): for key, row in self.key2row.items():
if row in self._unset: if row in self._unset:
self._unset.remove(row) self._unset.remove(row)