issue5230: replaced open statements on path objects so that serialization still works an files are closed

This commit is contained in:
Leander Fiedler 2020-04-06 20:30:41 +02:00 committed by lfiedler
parent 273ed452bb
commit 71cc903d65
3 changed files with 12 additions and 8 deletions

View File

@ -202,7 +202,7 @@ class Pipe(object):
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg) serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
serialize["vocab"] = lambda p: self.vocab.to_disk(p) serialize["vocab"] = lambda p: self.vocab.to_disk(p)
if self.model not in (None, True, False): if self.model not in (None, True, False):
serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes()) serialize["model"] = self.model.to_disk
exclude = util.get_serialization_exclude(serialize, exclude, kwargs) exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
util.to_disk(path, serialize, exclude) util.to_disk(path, serialize, exclude)
@ -625,7 +625,7 @@ class Tagger(Pipe):
serialize = OrderedDict(( serialize = OrderedDict((
("vocab", lambda p: self.vocab.to_disk(p)), ("vocab", lambda p: self.vocab.to_disk(p)),
("tag_map", lambda p: srsly.write_msgpack(p, tag_map)), ("tag_map", lambda p: srsly.write_msgpack(p, tag_map)),
("model", lambda p: p.open("wb").write(self.model.to_bytes())), ("model", self.model.to_disk),
("cfg", lambda p: srsly.write_json(p, self.cfg)) ("cfg", lambda p: srsly.write_json(p, self.cfg))
)) ))
exclude = util.get_serialization_exclude(serialize, exclude, kwargs) exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
@ -1394,7 +1394,7 @@ class EntityLinker(Pipe):
serialize["vocab"] = lambda p: self.vocab.to_disk(p) serialize["vocab"] = lambda p: self.vocab.to_disk(p)
serialize["kb"] = lambda p: self.kb.dump(p) serialize["kb"] = lambda p: self.kb.dump(p)
if self.model not in (None, True, False): if self.model not in (None, True, False):
serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes()) serialize["model"] = self.model.to_disk
exclude = util.get_serialization_exclude(serialize, exclude, kwargs) exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
util.to_disk(path, serialize, exclude) util.to_disk(path, serialize, exclude)

View File

@ -24,7 +24,6 @@ def test_language_to_disk_resource_warning():
assert len(w) == 0 assert len(w) == 0
@pytest.mark.xfail
def test_vectors_to_disk_resource_warning(): def test_vectors_to_disk_resource_warning():
data = numpy.zeros((3, 300), dtype="f") data = numpy.zeros((3, 300), dtype="f")
keys = ["cat", "dog", "rat"] keys = ["cat", "dog", "rat"]
@ -36,7 +35,6 @@ def test_vectors_to_disk_resource_warning():
assert len(w) == 0 assert len(w) == 0
@pytest.mark.xfail
def test_custom_pipes_to_disk_resource_warning(): def test_custom_pipes_to_disk_resource_warning():
# create dummy pipe partially implementing interface -- only want to test to_disk # create dummy pipe partially implementing interface -- only want to test to_disk
class SerializableDummy(object): class SerializableDummy(object):
@ -76,7 +74,6 @@ def test_custom_pipes_to_disk_resource_warning():
assert len(w) == 0 assert len(w) == 0
@pytest.mark.xfail
def test_tagger_to_disk_resource_warning(): def test_tagger_to_disk_resource_warning():
nlp = Language() nlp = Language()
nlp.add_pipe(nlp.create_pipe("tagger")) nlp.add_pipe(nlp.create_pipe("tagger"))
@ -93,7 +90,6 @@ def test_tagger_to_disk_resource_warning():
assert len(w) == 0 assert len(w) == 0
@pytest.mark.xfail
def test_entity_linker_to_disk_resource_warning(): def test_entity_linker_to_disk_resource_warning():
nlp = Language() nlp = Language()
nlp.add_pipe(nlp.create_pipe("entity_linker")) nlp.add_pipe(nlp.create_pipe("entity_linker"))

View File

@ -376,8 +376,16 @@ cdef class Vectors:
save_array = lambda arr, file_: xp.save(file_, arr, allow_pickle=False) save_array = lambda arr, file_: xp.save(file_, arr, allow_pickle=False)
else: else:
save_array = lambda arr, file_: xp.save(file_, arr) save_array = lambda arr, file_: xp.save(file_, arr)
def save_vectors(path):
# the source of numpy.save indicates that the file object is closed after use.
# but it seems that somehow this does not happen, as ResourceWarnings are raised here.
# in order to not rely on this, wrap in context manager.
with path.open("wb") as _file:
save_array(self.data, _file)
serializers = OrderedDict(( serializers = OrderedDict((
("vectors", lambda p: save_array(self.data, p.open("wb"))), ("vectors", save_vectors),
("key2row", lambda p: srsly.write_msgpack(p, self.key2row)) ("key2row", lambda p: srsly.write_msgpack(p, self.key2row))
)) ))
return util.to_disk(path, serializers, []) return util.to_disk(path, serializers, [])