2017-10-27 20:45:19 +03:00
|
|
|
# coding: utf8
|
2017-08-19 22:27:35 +03:00
|
|
|
from __future__ import unicode_literals
|
2017-10-27 20:45:19 +03:00
|
|
|
|
2017-06-05 13:32:08 +03:00
|
|
|
import numpy
|
|
|
|
from collections import OrderedDict
|
|
|
|
import msgpack
|
|
|
|
import msgpack_numpy
|
|
|
|
msgpack_numpy.patch()
|
2017-08-18 21:45:48 +03:00
|
|
|
cimport numpy as np
|
2017-09-16 20:45:09 +03:00
|
|
|
from thinc.neural.util import get_array_module
|
|
|
|
from thinc.neural._classes.model import Model
|
2017-06-05 13:32:08 +03:00
|
|
|
|
|
|
|
from .strings cimport StringStore
|
2017-10-16 21:55:00 +03:00
|
|
|
from .compat import basestring_, path2str
|
2017-10-27 20:45:19 +03:00
|
|
|
from . import util
|
2017-06-05 13:32:08 +03:00
|
|
|
|
|
|
|
|
|
|
|
cdef class Vectors:
    """Store, save and load word vectors.

    Vectors data is kept in the vectors.data attribute, which should be an
    instance of numpy.ndarray (for CPU vectors) or cupy.ndarray
    (for GPU vectors). `vectors.key2row` is a dictionary mapping word hashes to
    rows in the vectors.data table.

    Multiple keys can be mapped to the same vector, so len(keys) may be greater
    (but not smaller) than data.shape[0].
    """
    # The vector table; one row per vector.
    cdef public object data
    # String store used to convert string keys to integer hashes.
    cdef readonly StringStore strings
    # Dict mapping integer keys to row indices in `data`.
    cdef public object key2row
    # uint64 array holding, for each row, the key last written to it by `add`.
    cdef public object keys
    # Index of the next unassigned row; also what `len(vectors)` returns.
    cdef public int i

    def __init__(self, strings, width=0, data=None):
        """Create a new vector store. To keep the vector table empty, pass
        `width=0`. You can also create the vector table and add vectors one by
        one, or set the vector values directly on initialisation.

        strings (StringStore or list): List of strings or StringStore that maps
            strings to hash values, and vice versa.
        width (int): Number of dimensions.
        data (numpy.ndarray): The vector data.
        RETURNS (Vectors): The newly created object.
        """
        if isinstance(strings, StringStore):
            self.strings = strings
        else:
            self.strings = StringStore()
            for string in strings:
                self.strings.add(string)
        if data is not None:
            self.data = numpy.asarray(data, dtype='f')
        else:
            # No data given: allocate one zeroed row per known string.
            self.data = numpy.zeros((len(self.strings), width), dtype='f')
        self.i = 0
        self.key2row = {}
        self.keys = numpy.zeros((self.data.shape[0],), dtype='uint64')
        if data is not None:
            # Pair the strings with the provided rows, in iteration order,
            # stopping once the rows run out.
            for i, string in enumerate(self.strings):
                if i >= self.data.shape[0]:
                    break
                self.add(self.strings[string], vector=self.data[i])

    def __reduce__(self):
        return (Vectors, (self.strings, self.data))

    def __getitem__(self, key):
        """Get a vector by key. If key is a string, it is hashed to an integer
        ID using the vectors.strings table. If the integer key is not found in
        the table, a KeyError is raised.

        key (unicode / int): The key to get the vector for.
        RETURNS (numpy.ndarray): The vector for the key.
        """
        if isinstance(key, basestring_):
            key = self.strings[key]
        # The dict lookup itself raises KeyError(key) for unknown keys.
        return self.data[self.key2row[key]]

    def __setitem__(self, key, vector):
        """Set a vector for the given key. If key is a string, it is hashed
        to an integer ID using the vectors.strings table. The key must already
        have a row (see `add`), otherwise a KeyError is raised.

        key (unicode / int): The key to set the vector for.
        vector (numpy.ndarray): The vector to set.
        """
        if isinstance(key, basestring_):
            key = self.strings.add(key)
        self.data[self.key2row[key]] = vector

    def __iter__(self):
        """Yield vectors from the table.

        YIELDS (numpy.ndarray): A vector.
        """
        yield from self.data

    def __len__(self):
        """Return the number of vectors that have been assigned.

        RETURNS (int): The number of vectors in the data.
        """
        return self.i

    def __contains__(self, key):
        """Check whether a key has a vector entry in the table.

        key (unicode / int): The key to check.
        RETURNS (bool): Whether the key has a vector entry.
        """
        if isinstance(key, basestring_):
            key = self.strings[key]
        return key in self.key2row

    def add(self, key, *, vector=None, row=None):
        """Add a key to the table. Keys can be mapped to an existing vector
        by setting `row`, or a new vector can be added.

        key (unicode / int): The key to add.
        vector (numpy.ndarray / None): A vector to add for the key.
        row (int / None): The row-number of a vector to map the key to.
        RETURNS (int): The row the key is mapped to.
        """
        if isinstance(key, basestring_):
            key = self.strings.add(key)
        if row is None:
            if key in self.key2row:
                # Known key: reuse the row it already points at.
                row = self.key2row[key]
            else:
                # New key: claim the next free row, growing the table if full.
                row = self.i
                self.i += 1
                if row >= self.keys.shape[0]:
                    # Double the capacity, but always leave room for `row`
                    # itself: doubling an empty table (row == 0) would
                    # otherwise keep it at size 0 and the writes below
                    # would raise IndexError.
                    new_size = max(row * 2, row + 1)
                    self.keys.resize((new_size,))
                    self.data.resize((new_size, self.data.shape[1]))
        self.key2row[key] = row
        self.keys[row] = key
        if vector is not None:
            self.data[row] = vector
        return row

    def items(self):
        """Iterate over `(string key, vector)` pairs, in order.

        YIELDS (tuple): A key/vector pair.
        """
        for key in self.keys:
            string = self.strings[key]
            row = self.key2row[key]
            yield string, self.data[row]

    @property
    def shape(self):
        """Get `(rows, dims)` tuples of number of rows and number of dimensions
        in the vector table.

        RETURNS (tuple): A `(rows, dims)` pair.
        """
        return self.data.shape

    def most_similar(self, key):
        # TODO: implement
        raise NotImplementedError

    def from_glove(self, path):
        """Load GloVe vectors from a directory. Assumes binary format,
        that the vocab is in a vocab.txt, and that vectors are named
        vectors.{size}.[fd].bin, e.g. vectors.128.f.bin for 128d float32
        vectors, vectors.300.d.bin for 300d float64 (double) vectors, etc.
        By default GloVe outputs 64-bit vectors.

        The loaded table replaces `vectors.data`, and the keys are rebuilt
        from vocab.txt: line `i` of the vocab is mapped to row `i`.

        path (unicode / Path): The path to load the GloVe vectors from.
        """
        path = util.ensure_path(path)
        for name in path.iterdir():
            if name.parts[-1].startswith('vectors'):
                _, dims, dtype, _2 = name.parts[-1].split('.')
                width = int(dims)
                break
        else:
            raise IOError("Expected file named e.g. vectors.128.f.bin")
        # 'f' and 'd' are also the numpy type codes for float32 / float64, so
        # the letter from the file name can be used as the dtype directly.
        if dtype not in ('f', 'd'):
            raise IOError("Expected file named e.g. vectors.128.f.bin")
        bin_loc = path / 'vectors.{dims}.{dtype}.bin'.format(dims=dims,
                                                             dtype=dtype)
        with bin_loc.open('rb') as file_:
            data = numpy.fromfile(file_, dtype=dtype)
        if width <= 0 or (data.size % width) != 0:
            raise ValueError(
                "Can't reshape {n} values into rows of width {width}".format(
                    n=data.size, width=width))
        # The file is a flat array of values: convert to float32 and fold it
        # into a (rows, width) table.
        data = numpy.ascontiguousarray(data, dtype='float32')
        self.data = data.reshape((data.size // width, width))
        # The table was replaced wholesale, so restart the key bookkeeping.
        self.keys = numpy.zeros((self.data.shape[0],), dtype='uint64')
        self.key2row = {}
        n_keys = 0
        with (path / 'vocab.txt').open('r') as file_:
            for line in file_:
                # Pin each vocab entry to the row with the same index; a vocab
                # longer than the table raises IndexError from `add`.
                self.add(line.strip(), row=n_keys)
                n_keys += 1
        self.i = n_keys

    def to_disk(self, path, **exclude):
        """Save the current state to a directory.

        path (unicode / Path): A path to a directory, which will be created if
            it doesn't exist. Either a string or a Path-like object.
        """
        xp = get_array_module(self.data)

        def save_vectors(p):
            # Use a context manager so the file handle is always closed.
            with p.open('wb') as file_:
                if xp is numpy:
                    xp.save(file_, self.data, allow_pickle=False)
                else:
                    xp.save(file_, self.data)

        def save_keys(p):
            with p.open('wb') as file_:
                xp.save(file_, self.keys)

        serializers = OrderedDict((
            ('vectors', save_vectors),
            ('keys', save_keys)
        ))
        return util.to_disk(path, serializers, exclude)

    def from_disk(self, path, **exclude):
        """Loads state from a directory. Modifies the object in place and
        returns it.

        path (unicode / Path): Directory path, string or Path-like object.
        RETURNS (Vectors): The modified object.
        """
        def load_keys(path):
            if path.exists():
                self.keys = numpy.load(path2str(path))
                # Rebuild the key -> row mapping from the stored key order.
                for i, key in enumerate(self.keys):
                    self.key2row[key] = i

        def load_vectors(path):
            xp = Model.ops.xp
            if path.exists():
                # Pass a plain string path, as load_keys does.
                self.data = xp.load(path2str(path))

        serializers = OrderedDict((
            ('keys', load_keys),
            ('vectors', load_vectors),
        ))
        util.from_disk(path, serializers, exclude)
        return self

    def to_bytes(self, **exclude):
        """Serialize the current state to a binary string.

        **exclude: Named attributes to prevent from being serialized.
        RETURNS (bytes): The serialized form of the `Vectors` object.
        """
        def serialize_weights():
            # Prefer the data object's own serializer when it provides one.
            if hasattr(self.data, 'to_bytes'):
                return self.data.to_bytes()
            else:
                return msgpack.dumps(self.data)

        serializers = OrderedDict((
            ('keys', lambda: msgpack.dumps(self.keys)),
            ('vectors', serialize_weights)
        ))
        return util.to_bytes(serializers, exclude)

    def from_bytes(self, data, **exclude):
        """Load state from a binary string.

        data (bytes): The data to load from.
        **exclude: Named attributes to prevent from being loaded.
        RETURNS (Vectors): The `Vectors` object.
        """
        def deserialize_weights(b):
            if hasattr(self.data, 'from_bytes'):
                # Hand the payload to the data object's own deserializer.
                self.data.from_bytes(b)
            else:
                self.data = msgpack.loads(b)

        def load_keys(keys):
            # Resize the key array to fit, then rebuild the key -> row mapping.
            self.keys.resize((len(keys),))
            for i, key in enumerate(keys):
                self.keys[i] = key
                self.key2row[key] = i

        deserializers = OrderedDict((
            ('keys', lambda b: load_keys(msgpack.loads(b))),
            ('vectors', deserialize_weights)
        ))
        util.from_bytes(data, deserializers, exclude)
        return self
|