2017-10-27 20:45:19 +03:00
|
|
|
# coding: utf8
|
2017-08-19 22:27:35 +03:00
|
|
|
from __future__ import unicode_literals
|
2017-10-27 20:45:19 +03:00
|
|
|
|
2017-06-05 13:32:08 +03:00
|
|
|
import numpy
|
|
|
|
from collections import OrderedDict
|
|
|
|
import msgpack
|
|
|
|
import msgpack_numpy
|
|
|
|
msgpack_numpy.patch()
|
2017-08-18 21:45:48 +03:00
|
|
|
cimport numpy as np
|
2017-09-16 20:45:09 +03:00
|
|
|
from thinc.neural.util import get_array_module
|
|
|
|
from thinc.neural._classes.model import Model
|
2017-06-05 13:32:08 +03:00
|
|
|
|
2017-11-01 02:34:55 +03:00
|
|
|
from .strings cimport StringStore, hash_string
|
2017-10-16 21:55:00 +03:00
|
|
|
from .compat import basestring_, path2str
|
2017-10-27 20:45:19 +03:00
|
|
|
from . import util
|
2017-06-05 13:32:08 +03:00
|
|
|
|
|
|
|
|
2017-10-31 20:25:08 +03:00
|
|
|
def unpickle_vectors(keys_and_rows, data):
    """Rebuild a `Vectors` table from its pickled state.

    This is the reconstructor returned by `Vectors.__reduce__`.

    keys_and_rows (iterable): `(key, row)` pairs to map back onto the table.
    data (ndarray): The vector table to wrap.
    RETURNS (Vectors): The restored object.
    """
    restored = Vectors(data=data)
    for entry_key, entry_row in keys_and_rows:
        restored.add(entry_key, row=entry_row)
    return restored
|
2017-10-31 20:25:08 +03:00
|
|
|
|
|
|
|
|
2017-06-05 13:32:08 +03:00
|
|
|
cdef class Vectors:
    """Store, save and load word vectors.

    Vectors data is kept in the vectors.data attribute, which should be an
    instance of numpy.ndarray (for CPU vectors) or cupy.ndarray
    (for GPU vectors). `vectors.key2row` is a dictionary mapping word hashes to
    rows in the vectors.data table.

    Multiple keys can be mapped to the same vector, and not all of the rows in
    the table need to be assigned --- so len(list(vectors.keys())) may be
    greater or smaller than vectors.shape[0].
    """
    cdef public object data
    cdef public object key2row
    # Set of row indices that no key has been mapped to yet.
    cdef public object _unset

    def __init__(self, *, shape=None, data=None, keys=None):
        """Create a new vector store.

        shape (tuple): Size of the table, as (# entries, # columns). Only
            used when `data` is None; defaults to (0, 0).
        data (numpy.ndarray): The vector data.
        keys (iterable): A sequence of keys, aligned with the data.
        RETURNS (Vectors): The newly created object.
        """
        if data is None:
            if shape is None:
                shape = (0, 0)
            data = numpy.zeros(shape, dtype='f')
        self.data = data
        self.key2row = OrderedDict()
        # Every row starts out unassigned; add() claims rows as keys arrive.
        self._unset = set(range(self.data.shape[0]))
        if keys is not None:
            for i, key in enumerate(keys):
                self.add(key, row=i)

    def _sync_unset(self):
        """Recompute `_unset` from the current `data` and `key2row`.

        Needed whenever `data` or `key2row` is replaced wholesale (e.g. when
        loading), because `add()` only updates `_unset` incrementally.
        """
        filled = set(self.key2row.values())
        self._unset = {row for row in range(self.data.shape[0])
                       if row not in filled}

    @property
    def shape(self):
        """Get `(rows, dims)` tuples of number of rows and number of dimensions
        in the vector table.

        RETURNS (tuple): A `(rows, dims)` pair.
        """
        return self.data.shape

    @property
    def size(self):
        """RETURNS (int): rows*dims"""
        return self.data.shape[0] * self.data.shape[1]

    @property
    def is_full(self):
        """RETURNS (bool): `True` if no slots are available for new keys."""
        return len(self._unset) == 0

    @property
    def n_keys(self):
        """RETURNS (int): The number of keys in the table. Note that this is
        the number of all keys, not just unique vectors."""
        return len(self.key2row)

    def __reduce__(self):
        """Pickle support: rebuild via `unpickle_vectors(keys_and_rows, data)`."""
        keys_and_rows = tuple(self.key2row.items())
        return (unpickle_vectors, (keys_and_rows, self.data))

    def __getitem__(self, key):
        """Get a vector by key. If the key is not found, a KeyError is raised.

        key (int): The key to get the vector for.
        RETURNS (ndarray): The vector for the key.
        """
        i = self.key2row.get(key, None)
        if i is None:
            raise KeyError(key)
        return self.data[i]

    def __setitem__(self, key, vector):
        """Set a vector for the given key. The key must already be mapped to
        a row; otherwise a KeyError is raised.

        key (int): The key to set the vector for.
        vector (ndarray): The vector to set.
        """
        i = self.key2row[key]
        self.data[i] = vector
        if i in self._unset:
            self._unset.remove(i)

    def __iter__(self):
        """Iterate over the keys in the table.

        YIELDS (int): A key in the table.
        """
        yield from self.key2row

    def __len__(self):
        """Return the number of rows in the vector table.

        RETURNS (int): The number of vectors in the data.
        """
        return self.data.shape[0]

    def __contains__(self, key):
        """Check whether a key has been mapped to a vector entry in the table.

        key (int): The key to check.
        RETURNS (bool): Whether the key has a vector entry.
        """
        return key in self.key2row

    def resize(self, shape, inplace=False):
        """Resize the underlying vectors array.

        If inplace=True, the existing array is resized in place with
        `ndarray.resize(refcheck=False)`. This may cause other references to
        the data to become invalid, so only use inplace=True if you're sure
        that's what you want. Otherwise a new resized array replaces `data`.

        If the number of vectors is reduced, keys mapped to rows that have been
        deleted are removed. These removed items are returned as a list of
        `(key, row)` tuples.

        shape (tuple): The new `(rows, dims)` shape.
        inplace (bool): Whether to resize the existing array in place.
        RETURNS (list): The removed `(key, row)` tuples.
        """
        if inplace:
            self.data.resize(shape, refcheck=False)
        else:
            xp = get_array_module(self.data)
            self.data = xp.resize(self.data, shape)
        removed_items = []
        # Iterate over a copy: entries are popped during the loop.
        for key, row in list(self.key2row.items()):
            if row >= shape[0]:
                self.key2row.pop(key)
                removed_items.append((key, row))
        self._sync_unset()
        return removed_items

    def keys(self):
        """A sequence of the keys in the table.

        RETURNS (iterable): The keys.
        """
        return self.key2row.keys()

    def values(self):
        """Iterate over vectors that have been assigned to at least one key.

        Note that some vectors may be unassigned, so the number of vectors
        returned may be less than the length of the vectors table.

        YIELDS (ndarray): A vector in the table.
        """
        for row in range(self.data.shape[0]):
            if row not in self._unset:
                yield self.data[row]

    def items(self):
        """Iterate over `(key, vector)` pairs.

        YIELDS (tuple): A key/vector pair.
        """
        for key, row in self.key2row.items():
            yield key, self.data[row]

    def find(self, *, key=None, keys=None, row=None, rows=None):
        """Look up one or more keys by row, or vice versa.

        key (unicode / int): Find the row that the given key points to.
            Returns int, -1 if missing.
        keys (iterable): Find rows that the keys point to.
            Returns ndarray, with -1 for missing keys.
        row (int): Find the first key that points to the row.
            Returns ndarray holding at most one key.
        rows (iterable): Find the keys that point to the rows, at most one
            key per row. Returns ndarray; keys follow the iteration order of
            `key2row`, not the order of `rows`.
        RETURNS: The requested key, keys, row or rows.
        """
        if sum(arg is None for arg in (key, keys, row, rows)) != 3:
            raise ValueError("One (and only one) keyword arg must be set.")
        xp = get_array_module(self.data)
        if key is not None:
            if isinstance(key, basestring_):
                key = hash_string(key)
            return self.key2row.get(key, -1)
        elif keys is not None:
            keys = [hash_string(k) if isinstance(k, basestring_) else k
                    for k in keys]
            found_rows = [self.key2row.get(k, -1) for k in keys]
            return xp.asarray(found_rows, dtype='i')
        else:
            targets = set()
            if row is not None:
                targets.add(row)
            else:
                targets.update(rows)
            results = []
            for k, r in self.key2row.items():
                if r in targets:
                    results.append(k)
                    # Only report the first key found for each row.
                    targets.remove(r)
            return xp.asarray(results, dtype='uint64')

    def add(self, key, *, vector=None, row=None):
        """Add a key to the table. Keys can be mapped to an existing vector
        by setting `row`, or a new vector can be added.

        key (unicode / int): The key to add. Strings are hashed first.
        vector (ndarray / None): A vector to add for the key.
        row (int / None): The row number of a vector to map the key to.
        RETURNS (int): The row the vector was added to.
        RAISES (ValueError): If a new row is needed but none are free.
        """
        if isinstance(key, basestring_):
            key = hash_string(key)
        if row is None and key in self.key2row:
            row = self.key2row[key]
        elif row is None:
            if self.is_full:
                raise ValueError("Cannot add new key to vectors -- full")
            # Fill the lowest free slot first.
            row = min(self._unset)
        self.key2row[key] = row
        if vector is not None:
            self.data[row] = vector
        if row in self._unset:
            self._unset.remove(row)
        return row

    def most_similar(self, queries, *, batch_size=1024):
        """For each of the given vectors, find the single entry most similar
        to it, by cosine.

        Queries are by vector. Results are returned as a `(keys, best_rows,
        scores)` tuple. If `queries` is large, the calculations are performed in
        chunks, to avoid consuming too much memory. You can set the `batch_size`
        to control the size/space trade-off during the calculations. The
        `queries` array is not modified. All-zero vectors (in the table or in
        the queries) get a similarity of 0 rather than NaN.

        queries (ndarray): An array with one or more vectors.
        batch_size (int): The batch size to use.
        RETURNS (tuple): The most similar entry as a `(keys, best_rows, scores)`
            tuple.
        """
        xp = get_array_module(self.data)
        norms = xp.linalg.norm(self.data, axis=1, keepdims=True)
        # Zero rows (e.g. unassigned slots) would otherwise produce 0/0 = NaN,
        # and NaN wins argmax; divide them by 1 so they score 0 instead.
        norms[norms == 0] = 1
        vectors = self.data / norms
        n_queries = queries.shape[0]
        best_rows = xp.zeros((n_queries,), dtype='i')
        scores = xp.zeros((n_queries,), dtype='f')
        # Work in batches, to avoid memory problems.
        for i in range(0, n_queries, batch_size):
            batch = queries[i : i+batch_size]
            batch_norms = xp.linalg.norm(batch, axis=1, keepdims=True)
            batch_norms[batch_norms == 0] = 1
            # Out-of-place division: `batch` is a view into the caller's array.
            batch = batch / batch_norms
            # batch e.g. (1024, 300)
            # vectors e.g. (10000, 300)
            # sims e.g. (1024, 10000)
            sims = xp.dot(batch, vectors.T)
            best_rows[i:i+batch_size] = sims.argmax(axis=1)
            scores[i:i+batch_size] = sims.max(axis=1)
        row2key = {row: key for key, row in self.key2row.items()}
        keys = xp.asarray([row2key[row] for row in best_rows.tolist()],
                          dtype='uint64')
        return (keys, best_rows, scores)

    def from_glove(self, path):
        """Load GloVe vectors from a directory. Assumes binary format,
        that the vocab is in a vocab.txt, and that vectors are named
        vectors.{size}.[fd].bin, e.g. vectors.128.f.bin for 128d float32
        vectors, vectors.300.d.bin for 300d float64 (double) vectors, etc.
        By default GloVe outputs 64-bit vectors. The loaded table replaces
        the current data and key mapping, and is stored as float32.

        path (unicode / Path): The path to load the GloVe vectors from.
        RETURNS: A `StringStore` object, holding the key-to-string mapping.
        """
        path = util.ensure_path(path)
        for name in path.iterdir():
            if name.parts[-1].startswith('vectors'):
                _, dims, dtype, _2 = name.parts[-1].split('.')
                width = int(dims)
                break
        else:
            raise IOError("Expected file named e.g. vectors.128.f.bin")
        bin_loc = path / 'vectors.{dims}.{dtype}.bin'.format(dims=dims,
                                                              dtype=dtype)
        xp = get_array_module(self.data)
        with bin_loc.open('rb') as file_:
            data = xp.fromfile(file_, dtype=dtype)
        if data.dtype != 'float32':
            data = xp.ascontiguousarray(data, dtype='float32')
        # The binary file is a flat stream of numbers; restore the 2D layout.
        self.data = data.reshape((data.size // width, width))
        # Old mappings refer to the replaced table, so start from scratch.
        self.key2row = OrderedDict()
        self._unset = set(range(self.data.shape[0]))
        strings = StringStore()
        with (path / 'vocab.txt').open('r', encoding='utf8') as file_:
            for i, line in enumerate(file_):
                key = strings.add(line.strip())
                self.add(key, row=i)
        return strings

    def to_disk(self, path, **exclude):
        """Save the current state to a directory.

        path (unicode / Path): A path to a directory, which will be created if
            it doesn't exist. Either a string or a Path-like object.
        **exclude: Named attributes to prevent from being saved.
        """
        xp = get_array_module(self.data)

        def save_array(arr, file_):
            if xp is numpy:
                xp.save(file_, arr, allow_pickle=False)
            else:
                xp.save(file_, arr)

        def save_vectors(p):
            with p.open('wb') as file_:
                save_array(self.data, file_)

        def save_key2row(p):
            with p.open('wb') as file_:
                msgpack.dump(self.key2row, file_)

        serializers = OrderedDict((
            ('vectors', save_vectors),
            ('key2row', save_key2row),
        ))
        return util.to_disk(path, serializers, exclude)

    def from_disk(self, path, **exclude):
        """Loads state from a directory. Modifies the object in place and
        returns it.

        path (unicode / Path): Directory path, string or Path-like object.
        **exclude: Named attributes to prevent from being loaded.
        RETURNS (Vectors): The modified object.
        """
        def load_key2row(path):
            if path.exists():
                with path.open('rb') as file_:
                    self.key2row = msgpack.load(file_)

        def load_keys(path):
            if path.exists():
                keys = numpy.load(str(path))
                for i, key in enumerate(keys):
                    self.add(key, row=i)

        def load_vectors(path):
            xp = Model.ops.xp
            if path.exists():
                self.data = xp.load(str(path))

        serializers = OrderedDict((
            ('key2row', load_key2row),
            ('keys', load_keys),
            ('vectors', load_vectors),
        ))
        util.from_disk(path, serializers, exclude)
        # `vectors` is loaded last, so the free-row set must be rebuilt here.
        self._sync_unset()
        return self

    def to_bytes(self, **exclude):
        """Serialize the current state to a binary string.

        **exclude: Named attributes to prevent from being serialized.
        RETURNS (bytes): The serialized form of the `Vectors` object.
        """
        def serialize_weights():
            if hasattr(self.data, 'to_bytes'):
                return self.data.to_bytes()
            else:
                return msgpack.dumps(self.data)

        serializers = OrderedDict((
            ('key2row', lambda: msgpack.dumps(self.key2row)),
            ('vectors', serialize_weights)
        ))
        return util.to_bytes(serializers, exclude)

    def from_bytes(self, data, **exclude):
        """Load state from a binary string.

        data (bytes): The data to load from.
        **exclude: Named attributes to prevent from being loaded.
        RETURNS (Vectors): The `Vectors` object.
        """
        def deserialize_weights(b):
            if hasattr(self.data, 'from_bytes'):
                self.data.from_bytes(b)
            else:
                self.data = msgpack.loads(b)

        deserializers = OrderedDict((
            ('key2row', lambda b: self.key2row.update(msgpack.loads(b))),
            ('vectors', deserialize_weights)
        ))
        util.from_bytes(data, deserializers, exclude)
        # The table and/or mapping changed wholesale; rebuild the free rows.
        self._sync_unset()
        return self
|