mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Save vectors as little endian, load with Ops.asarray (#10201)
* Save vectors as little endian, load with Ops.asarray * Always save vector data as little endian * Always run `Vectors.to_ops` when vector data is loaded so that `Ops.asarray` can be used to load the data correctly for the current ops. * Update spacy/vectors.pyx Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Update spacy/vectors.pyx Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
parent
107bab56b5
commit
c17980e535
|
@ -565,8 +565,9 @@ cdef class Vectors:
|
||||||
# the source of numpy.save indicates that the file object is closed after use.
|
# the source of numpy.save indicates that the file object is closed after use.
|
||||||
# but it seems that somehow this does not happen, as ResourceWarnings are raised here.
|
# but it seems that somehow this does not happen, as ResourceWarnings are raised here.
|
||||||
# in order to not rely on this, wrap in context manager.
|
# in order to not rely on this, wrap in context manager.
|
||||||
|
ops = get_current_ops()
|
||||||
with path.open("wb") as _file:
|
with path.open("wb") as _file:
|
||||||
save_array(self.data, _file)
|
save_array(ops.to_numpy(self.data, byte_order="<"), _file)
|
||||||
|
|
||||||
serializers = {
|
serializers = {
|
||||||
"strings": lambda p: self.strings.to_disk(p.with_suffix(".json")),
|
"strings": lambda p: self.strings.to_disk(p.with_suffix(".json")),
|
||||||
|
@ -602,6 +603,7 @@ cdef class Vectors:
|
||||||
ops = get_current_ops()
|
ops = get_current_ops()
|
||||||
if path.exists():
|
if path.exists():
|
||||||
self.data = ops.xp.load(str(path))
|
self.data = ops.xp.load(str(path))
|
||||||
|
self.to_ops(ops)
|
||||||
|
|
||||||
def load_settings(path):
|
def load_settings(path):
|
||||||
if path.exists():
|
if path.exists():
|
||||||
|
@ -631,7 +633,8 @@ cdef class Vectors:
|
||||||
if hasattr(self.data, "to_bytes"):
|
if hasattr(self.data, "to_bytes"):
|
||||||
return self.data.to_bytes()
|
return self.data.to_bytes()
|
||||||
else:
|
else:
|
||||||
return srsly.msgpack_dumps(self.data)
|
ops = get_current_ops()
|
||||||
|
return srsly.msgpack_dumps(ops.to_numpy(self.data, byte_order="<"))
|
||||||
|
|
||||||
serializers = {
|
serializers = {
|
||||||
"strings": lambda: self.strings.to_bytes(),
|
"strings": lambda: self.strings.to_bytes(),
|
||||||
|
@ -656,6 +659,8 @@ cdef class Vectors:
|
||||||
else:
|
else:
|
||||||
xp = get_array_module(self.data)
|
xp = get_array_module(self.data)
|
||||||
self.data = xp.asarray(srsly.msgpack_loads(b))
|
self.data = xp.asarray(srsly.msgpack_loads(b))
|
||||||
|
ops = get_current_ops()
|
||||||
|
self.to_ops(ops)
|
||||||
|
|
||||||
deserializers = {
|
deserializers = {
|
||||||
"strings": lambda b: self.strings.from_bytes(b),
|
"strings": lambda b: self.strings.from_bytes(b),
|
||||||
|
|
Loading…
Reference in New Issue
Block a user