Merge pull request #1475 from explosion/feature/sm-vectors

Improve and simplify Vectors class
Matthew Honnibal 2017-10-31 22:59:50 +01:00 committed by GitHub
commit 0de8d213a3
10 changed files with 358 additions and 232 deletions

View File: spacy/_ml.py

@@ -29,6 +29,16 @@ from . import util
VECTORS_KEY = 'spacy_pretrained_vectors'
def cosine(vec1, vec2):
xp = get_array_module(vec1)
norm1 = xp.linalg.norm(vec1)
norm2 = xp.linalg.norm(vec2)
if norm1 == 0. or norm2 == 0.:
return 0
else:
return vec1.dot(vec2) / (norm1 * norm2)
@layerize
def _flatten_add_lengths(seqs, pad=0, drop=0.):
ops = Model.ops
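Aside: the new cosine helper above guards against zero-norm vectors before dividing. A quick sanity check (assuming the helper is importable as spacy._ml.cosine):

    import numpy
    from spacy._ml import cosine

    a = numpy.asarray([1., 0.], dtype='f')
    b = numpy.asarray([1., 1.], dtype='f')
    assert abs(cosine(a, b) - 1. / numpy.sqrt(2.)) < 1e-6
    assert cosine(a, numpy.zeros((2,), dtype='f')) == 0  # zero norm short-circuits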

View File: spacy/cli/train.py

@@ -32,7 +32,6 @@ numpy.random.seed(0)
n_sents=("number of sentences", "option", "ns", int),
use_gpu=("Use GPU", "option", "g", int),
vectors=("Model to load vectors from", "option", "v"),
vectors_limit=("Truncate to N vectors (requires -v)", "option", None, int),
no_tagger=("Don't train tagger", "flag", "T", bool),
no_parser=("Don't train parser", "flag", "P", bool),
no_entities=("Don't train NER", "flag", "N", bool),
@@ -41,7 +40,7 @@ numpy.random.seed(0)
meta_path=("Optional path to meta.json. All relevant properties will be "
"overwritten.", "option", "m", Path))
def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
use_gpu=-1, vectors=None, vectors_limit=None, no_tagger=False,
use_gpu=-1, vectors=None, no_tagger=False,
no_parser=False, no_entities=False, gold_preproc=False,
version="0.0.0", meta_path=None):
"""
@@ -95,8 +94,6 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
nlp.meta.update(meta)
if vectors:
util.load_model(vectors, vocab=nlp.vocab)
if vectors_limit is not None:
nlp.vocab.prune_vectors(vectors_limit)
for name in pipeline:
nlp.add_pipe(nlp.create_pipe(name), name=name)
optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
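Note that --vectors-limit is removed outright rather than renamed: pruning now belongs to vocab creation (next file), or can be done programmatically. A hedged sketch of the programmatic route, assuming a model whose vocab already has vectors loaded (model name hypothetical):

    import spacy

    nlp = spacy.load('en_core_web_md')      # hypothetical: any model with vectors
    remap = nlp.vocab.prune_vectors(10000)  # keep 10k rows, remap the rest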

View File: spacy/cli/vocab.py

@@ -7,6 +7,7 @@ import spacy
import numpy
from pathlib import Path
from ..vectors import Vectors
from ..util import prints, ensure_path
@@ -16,8 +17,12 @@ from ..util import prints, ensure_path
lexemes_loc=("location of JSONL-formatted lexical data", "positional",
None, Path),
vectors_loc=("optional: location of vectors data, as numpy .npz",
"positional", None, str))
def make_vocab(cmd, lang, output_dir, lexemes_loc, vectors_loc=None):
"positional", None, str),
prune_vectors=("optional: number of vectors to prune to.",
"option", "V", int)
)
def make_vocab(cmd, lang, output_dir, lexemes_loc,
vectors_loc=None, prune_vectors=-1):
"""Compile a vocabulary from a lexicon jsonl file and word vectors."""
if not lexemes_loc.exists():
prints(lexemes_loc, title="Can't find lexical data", exits=1)
@@ -26,7 +31,6 @@ def make_vocab(cmd, lang, output_dir, lexemes_loc, vectors_loc=None):
for word in nlp.vocab:
word.rank = 0
lex_added = 0
vec_added = 0
with lexemes_loc.open() as file_:
for line in file_:
if line.strip():
@@ -39,16 +43,18 @@ def make_vocab(cmd, lang, output_dir, lexemes_loc, vectors_loc=None):
assert lex.rank == attrs['id']
lex_added += 1
if vectors_loc is not None:
vector_data = numpy.load(open(vectors_loc, 'rb'))
nlp.vocab.clear_vectors(width=vector_data.shape[1])
vector_data = numpy.load(vectors_loc.open('rb'))
nlp.vocab.vectors = Vectors(data=vector_data)
for word in nlp.vocab:
if word.rank:
nlp.vocab.vectors.add(word.orth_, row=word.rank,
vector=vector_data[word.rank])
vec_added += 1
nlp.vocab.vectors.add(word.orth, row=word.rank)
if prune_vectors >= 1:
remap = nlp.vocab.prune_vectors(prune_vectors)
if not output_dir.exists():
output_dir.mkdir()
nlp.to_disk(output_dir)
vec_added = len(nlp.vocab.vectors)
prints("{} entries, {} vectors".format(lex_added, vec_added), output_dir,
title="Sucessfully compiled vocab and vectors, and saved model")
return nlp
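In outline, the rewritten vectors branch above builds one bulk table from the .npz array, maps each in-vocabulary word to its row by rank, then optionally prunes. Roughly (path hypothetical):

    import numpy
    from spacy.vectors import Vectors

    vector_data = numpy.load(open('vectors.npz', 'rb'))  # hypothetical path
    nlp.vocab.vectors = Vectors(data=vector_data)
    for word in nlp.vocab:
        if word.rank:
            nlp.vocab.vectors.add(word.orth, row=word.rank)
    nlp.vocab.prune_vectors(20000)  # optional; mirrors the new -V option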

View File: spacy/tests/doc/test_doc_api.py

@@ -208,8 +208,8 @@ def test_doc_api_right_edge(en_tokenizer):
def test_doc_api_has_vector():
vocab = Vocab()
vocab.clear_vectors(2)
vocab.vectors.add('kitten', vector=numpy.asarray([0., 2.], dtype='f'))
vocab.reset_vectors(width=2)
vocab.set_vector('kitten', vector=numpy.asarray([0., 2.], dtype='f'))
doc = Doc(vocab, words=['kitten'])
assert doc.has_vector

View File: spacy/tests/doc/test_token_api.py

@@ -72,9 +72,9 @@ def test_doc_token_api_is_properties(en_vocab):
def test_doc_token_api_vectors():
vocab = Vocab()
vocab.clear_vectors(2)
vocab.vectors.add('apples', vector=numpy.asarray([0., 2.], dtype='f'))
vocab.vectors.add('oranges', vector=numpy.asarray([0., 1.], dtype='f'))
vocab.reset_vectors(width=2)
vocab.set_vector('apples', vector=numpy.asarray([0., 2.], dtype='f'))
vocab.set_vector('oranges', vector=numpy.asarray([0., 1.], dtype='f'))
doc = Doc(vocab, words=['apples', 'oranges', 'oov'])
assert doc.has_vector

View File: spacy/tests/util.py

@@ -79,9 +79,9 @@ def add_vecs_to_vocab(vocab, vectors):
"""Add list of vector tuples to given vocab. All vectors need to have the
same length. Format: [("text", [1, 2, 3])]"""
length = len(vectors[0][1])
vocab.clear_vectors(length)
vocab.reset_vectors(width=length)
for word, vec in vectors:
vocab.set_vector(word, vec)
vocab.set_vector(word, vector=vec)
return vocab

View File: spacy/tests/vectors/test_vectors.py

@@ -35,20 +35,18 @@ def vocab(en_vocab, vectors):
def test_init_vectors_with_data(strings, data):
v = Vectors(strings, data=data)
v = Vectors(data=data)
assert v.shape == data.shape
def test_init_vectors_with_width(strings):
v = Vectors(strings, width=3)
for string in strings:
v.add(string)
def test_init_vectors_with_shape(strings):
v = Vectors(shape=(len(strings), 3))
assert v.shape == (len(strings), 3)
def test_get_vector(strings, data):
v = Vectors(strings, data=data)
for string in strings:
v.add(string)
v = Vectors(data=data)
for i, string in enumerate(strings):
v.add(string, row=i)
assert list(v[strings[0]]) == list(data[0])
assert list(v[strings[0]]) != list(data[1])
assert list(v[strings[1]]) != list(data[0])
@@ -56,9 +54,9 @@ def test_get_vector(strings, data):
def test_set_vector(strings, data):
orig = data.copy()
v = Vectors(strings, data=data)
for string in strings:
v.add(string)
v = Vectors(data=data)
for i, string in enumerate(strings):
v.add(string, row=i)
assert list(v[strings[0]]) == list(orig[0])
assert list(v[strings[0]]) != list(orig[1])
v[strings[0]] = data[1]
@@ -66,7 +64,6 @@ def test_set_vector(strings, data):
assert list(v[strings[0]]) != list(orig[0])
@pytest.fixture()
def tokenizer_v(vocab):
return Tokenizer(vocab, {}, None, None, None)
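These tests pin down the new constructor contract: Vectors no longer takes a strings store, keys are stored as given, and rows are assigned explicitly. A small sketch against the API in this diff:

    import numpy
    from spacy.vectors import Vectors

    data = numpy.asarray([[1., 1., 1.], [2., 2., 2.]], dtype='f')
    v = Vectors(data=data)
    v.add(u'cat', row=0)     # map a key to an existing row
    v.add(u'kitten', row=0)  # multiple keys may share one row
    assert list(v[u'cat']) == list(v[u'kitten'])
    assert v.n_keys == 2 and v.shape == (2, 3)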

View File: spacy/tests/vocab/test_add_vectors.py

@@ -2,14 +2,39 @@
from __future__ import unicode_literals
import numpy
import pytest
from numpy.testing import assert_allclose
from ...vocab import Vocab
from ..._ml import cosine
@pytest.mark.xfail
@pytest.mark.parametrize('text', ["Hello"])
def test_vocab_add_vector(en_vocab, text):
en_vocab.resize_vectors(10)
lex = en_vocab[text]
lex.vector = numpy.ndarray((10,), dtype='float32')
lex = en_vocab[text]
assert lex.vector.shape == (10,)
def test_vocab_add_vector():
vocab = Vocab()
data = numpy.ndarray((5,3), dtype='f')
data[0] = 1.
data[1] = 2.
vocab.set_vector(u'cat', data[0])
vocab.set_vector(u'dog', data[1])
cat = vocab[u'cat']
assert list(cat.vector) == [1., 1., 1.]
dog = vocab[u'dog']
assert list(dog.vector) == [2., 2., 2.]
def test_vocab_prune_vectors():
vocab = Vocab()
_ = vocab[u'cat']
_ = vocab[u'dog']
_ = vocab[u'kitten']
data = numpy.ndarray((5,3), dtype='f')
data[0] = 1.
data[1] = 2.
data[2] = 1.1
vocab.set_vector(u'cat', data[0])
vocab.set_vector(u'dog', data[1])
vocab.set_vector(u'kitten', data[2])
remap = vocab.prune_vectors(2)
assert list(remap.keys()) == [u'kitten']
neighbour, similarity = list(remap.values())[0]
assert neighbour == u'cat', remap
assert_allclose(similarity, cosine(data[0], data[2]), atol=1e-6)

View File: spacy/vectors.pyx

@@ -15,6 +15,12 @@ from .compat import basestring_, path2str
from . import util
def unpickle_vectors(keys_and_rows, data):
vectors = Vectors(data=data)
for key, row in keys_and_rows:
vectors.add(key, row=row)
return vectors
cdef class Vectors:
"""Store, save and load word vectors.
@@ -23,140 +29,35 @@ cdef class Vectors:
(for GPU vectors). `vectors.key2row` is a dictionary mapping word hashes to
rows in the vectors.data table.
Multiple keys can be mapped to the same vector, so len(keys) may be greater
(but not smaller) than data.shape[0].
Multiple keys can be mapped to the same vector, and not all of the rows in
the table need to be assigned --- so len(list(vectors.keys())) may be
greater or smaller than vectors.shape[0].
"""
cdef public object data
cdef readonly StringStore strings
cdef public object key2row
cdef public object keys
cdef public int _i_key
cdef public int _i_vec
cdef public object _unset
def __init__(self, strings, width=0, data=None):
"""Create a new vector store. To keep the vector table empty, pass
`width=0`. You can also create the vector table and add vectors one by
one, or set the vector values directly on initialisation.
strings (StringStore or list): List of strings or StringStore that maps
strings to hash values, and vice versa.
width (int): Number of dimensions.
def __init__(self, *, shape=None, data=None, keys=None):
"""Create a new vector store.
shape (tuple): Size of the table, as (# entries, # columns)
data (numpy.ndarray): The vector data.
RETURNS (Vectors): The newly created object.
"""
if isinstance(strings, StringStore):
self.strings = strings
if data is None:
if shape is None:
shape = (0,0)
data = numpy.zeros(shape, dtype='f')
self.data = data
self.key2row = OrderedDict()
if self.data is not None:
self._unset = set(range(self.data.shape[0]))
else:
self.strings = StringStore()
for string in strings:
self.strings.add(string)
if data is not None:
self.data = numpy.asarray(data, dtype='f')
else:
self.data = numpy.zeros((len(self.strings), width), dtype='f')
self._i_key = 0
self._i_vec = 0
self.key2row = {}
self.keys = numpy.zeros((self.data.shape[0],), dtype='uint64')
if data is not None:
for i, string in enumerate(self.strings):
if i >= self.data.shape[0]:
break
self.add(self.strings[string], vector=self.data[i])
def __reduce__(self):
return (Vectors, (self.strings, self.data))
def __getitem__(self, key):
"""Get a vector by key. If key is a string, it is hashed to an integer
ID using the vectors.strings table. If the integer key is not found in
the table, a KeyError is raised.
key (unicode / int): The key to get the vector for.
RETURNS (numpy.ndarray): The vector for the key.
"""
if isinstance(key, basestring):
key = self.strings[key]
i = self.key2row[key]
if i is None:
raise KeyError(key)
else:
return self.data[i]
def __setitem__(self, key, vector):
"""Set a vector for the given key. If key is a string, it is hashed
to an integer ID using the vectors.strings table.
key (unicode / int): The key to set the vector for.
vector (numpy.ndarray): The vector to set.
"""
if isinstance(key, basestring):
key = self.strings.add(key)
i = self.key2row[key]
self.data[i] = vector
def __iter__(self):
"""Yield vectors from the table.
YIELDS (numpy.ndarray): A vector.
"""
yield from self.data
def __len__(self):
"""Return the number of vectors that have been assigned.
RETURNS (int): The number of vectors in the data.
"""
return self._i_vec
def __contains__(self, key):
"""Check whether a key has a vector entry in the table.
key (unicode / int): The key to check.
RETURNS (bool): Whether the key has a vector entry.
"""
if isinstance(key, basestring_):
key = self.strings[key]
return key in self.key2row
def add(self, key, *, vector=None, row=None):
"""Add a key to the table. Keys can be mapped to an existing vector
by setting `row`, or a new vector can be added.
key (unicode / int): The key to add.
vector (numpy.ndarray / None): A vector to add for the key.
row (int / None): The row-number of a vector to map the key to.
"""
if isinstance(key, basestring_):
key = self.strings.add(key)
if row is None and key in self.key2row:
row = self.key2row[key]
elif row is None:
row = self._i_vec
self._i_vec += 1
if row >= self.data.shape[0]:
self.data.resize((row*2, self.data.shape[1]))
if key not in self.key2row:
if self._i_key >= self.keys.shape[0]:
self.keys.resize((self._i_key*2,))
self.keys[self._i_key] = key
self._i_key += 1
self.key2row[key] = row
if vector is not None:
self.data[row] = vector
return row
def items(self):
"""Iterate over `(string key, vector)` pairs, in order.
YIELDS (tuple): A key/vector pair.
"""
for i, key in enumerate(self.keys):
string = self.strings[key]
row = self.key2row[key]
yield string, self.data[row]
self._unset = set()
if keys is not None:
for i, key in enumerate(keys):
self.add(key, row=i)
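# (Illustrative, not part of the diff: the three construction modes.
#  Vectors() gives an empty (0, 0) table; Vectors(shape=(1000, 128))
#  preallocates 1000 unset rows; Vectors(data=arr, keys=words) maps
#  keys[i] to row i.)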
@property
def shape(self):
"""Get `(rows, dims)` tuples of number of rows and number of dimensions
@@ -166,9 +67,184 @@
"""
return self.data.shape
def most_similar(self, key):
# TODO: implement
raise NotImplementedError
@property
def size(self):
"""Return rows*dims"""
return self.data.shape[0] * self.data.shape[1]
@property
def is_full(self):
"""Returns True if no keys are available for new keys."""
return len(self._unset) == 0
@property
def n_keys(self):
"""Returns True if no keys are available for new keys."""
return len(self.key2row)
def __reduce__(self):
keys_and_rows = self.key2row.items()
return (unpickle_vectors, (keys_and_rows, self.data))
def __getitem__(self, key):
"""Get a vector by key. If the key is not found, a KeyError is raised.
key (int): The key to get the vector for.
RETURNS (ndarray): The vector for the key.
"""
i = self.key2row[key]
if i is None:
raise KeyError(key)
else:
return self.data[i]
def __setitem__(self, key, vector):
"""Set a vector for the given key.
key (int): The key to set the vector for.
vector (numpy.ndarray): The vector to set.
"""
i = self.key2row[key]
self.data[i] = vector
if i in self._unset:
self._unset.remove(i)
def __iter__(self):
"""Yield vectors from the table.
YIELDS (ndarray): A vector.
"""
yield from self.key2row
def __len__(self):
"""Return the number of vectors in the table.
RETURNS (int): The number of vectors in the data.
"""
return self.data.shape[0]
def __contains__(self, key):
"""Check whether a key has been mapped to a vector entry in the table.
key (int): The key to check.
RETURNS (bool): Whether the key has a vector entry.
"""
return key in self.key2row
def resize(self, shape, inplace=False):
'''Resize the underlying vectors array. If inplace=True, the memory
is reallocated. This may cause other references to the data to become
invalid, so only use inplace=True if you're sure that's what you want.
If the number of vectors is reduced, keys mapped to rows that have been
deleted are removed. These removed items are returned as a list of
(key, row) tuples.
'''
if inplace:
self.data.resize(shape, refcheck=False)
else:
xp = get_array_module(self.data)
self.data = xp.resize(self.data, shape)
filled = {row for row in self.key2row.values()}
self._unset = {row for row in range(shape[0]) if row not in filled}
removed_items = []
for key, row in list(self.key2row.items()):  # copy, since we pop below
if row >= shape[0]:
self.key2row.pop(key)
removed_items.append((key, row))
return removed_items
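# (Illustrative, not part of the diff: shrinking the table evicts keys,
#  e.g. removed = v.resize((100, v.shape[1])) returns the (key, row)
#  pairs for every key whose row was >= 100.)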
def keys(self):
'''Iterate over the keys in the table.'''
yield from self.key2row.keys()
def values(self):
'''Iterate over vectors that have been assigned to at least one key.
Note that some vectors may be unassigned, so the number of vectors
returned may be less than the length of the vectors table.'''
for row in range(self.data.shape[0]):
if row not in self._unset:
yield self.data[row]
def items(self):
"""Iterate over `(key, vector)` pairs.
YIELDS (tuple): A key/vector pair.
"""
for key, row in self.key2row.items():
yield key, self.data[row]
def get_keys(self, rows):
xp = get_array_module(self.data)
row2key = {row: key for key, row in self.key2row.items()}
keys = xp.asarray([row2key[row] for row in rows],
dtype='uint64')
return keys
def get_rows(self, keys):
xp = get_array_module(self.data)
k2r = self.key2row
return xp.asarray([k2r.get(key, -1) for key in keys], dtype='i')
def add(self, key, *, vector=None, row=None):
"""Add a key to the table. Keys can be mapped to an existing vector
by setting `row`, or a new vector can be added.
key (unicode / int): The key to add.
vector (numpy.ndarray / None): A vector to add for the key.
row (int / None): The row-number of a vector to map the key to.
"""
if row is None and key in self.key2row:
row = self.key2row[key]
elif row is None:
if self.is_full:
raise ValueError("Cannot add new key to vectors -- full")
row = min(self._unset)
self.key2row[key] = row
if vector is not None:
self.data[row] = vector
if row in self._unset:
self._unset.remove(row)
return row
def most_similar(self, queries, *, return_scores=False, return_rows=False,
batch_size=1024):
'''For each of the given vectors, find the single entry most similar
to it, by cosine.
Queries are by vector. Results are returned as an array of keys,
or a tuple of (keys, scores) if return_scores=True. If `queries` is
large, the calculations are performed in chunks, to avoid consuming
too much memory. You can set the `batch_size` to control the
speed/memory trade-off during the calculations.
'''
xp = get_array_module(self.data)
vectors = self.data / xp.linalg.norm(self.data, axis=1, keepdims=True)
best_rows = xp.zeros((queries.shape[0],), dtype='i')
scores = xp.zeros((queries.shape[0],), dtype='f')
# Work in batches, to avoid memory problems.
for i in range(0, queries.shape[0], batch_size):
batch = queries[i : i+batch_size]
batch /= xp.linalg.norm(batch, axis=1, keepdims=True)
# batch e.g. (1024, 300)
# vectors e.g. (10000, 300)
# sims e.g. (1024, 10000)
sims = xp.dot(batch, vectors.T)
best_rows[i:i+batch_size] = sims.argmax(axis=1)
scores[i:i+batch_size] = sims.max(axis=1)
keys = self.get_keys(best_rows)
if return_rows and return_scores:
return (keys, best_rows, scores)
elif return_rows:
return (keys, best_rows)
elif return_scores:
return (keys, scores)
else:
return keys
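# (Illustrative, not part of the diff: queries is a 2d array with one
#  vector per row, e.g.
#      keys = v.most_similar(queries)
#      keys, scores = v.most_similar(queries, return_scores=True)
#  and with return_rows=True as well it returns (keys, best_rows, scores).)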
def from_glove(self, path):
"""Load GloVe vectors from a directory. Assumes binary format,
@@ -178,27 +254,33 @@
By default GloVe outputs 64-bit vectors.
path (unicode / Path): The path to load the GloVe vectors from.
RETURNS: A StringStore object, holding the key-to-string mapping.
"""
path = util.ensure_path(path)
width = None
for name in path.iterdir():
if name.parts[-1].startswith('vectors'):
_, dims, dtype, _2 = name.parts[-1].split('.')
self.width = int(dims)
width = int(dims)
break
else:
raise IOError("Expected file named e.g. vectors.128.f.bin")
bin_loc = path / 'vectors.{dims}.{dtype}.bin'.format(dims=dims,
dtype=dtype)
xp = get_array_module(self.data)
self.data = None
with bin_loc.open('rb') as file_:
self.data = numpy.fromfile(file_, dtype='float64')
self.data = numpy.ascontiguousarray(self.data, dtype='float32')
self.data = xp.fromfile(file_, dtype=dtype)
if dtype != 'float32':
self.data = xp.ascontiguousarray(self.data, dtype='float32')
n = 0
strings = StringStore()
with (path / 'vocab.txt').open('r') as file_:
for line in file_:
self.add(line.strip())
n += 1
if (self.data.size % self.width) == 0:
self.data
for i, line in enumerate(file_):
key = strings.add(line.strip())
self.add(key, row=i)
return strings
def to_disk(self, path, **exclude):
"""Save the current state to a directory.
@@ -214,7 +296,7 @@
save_array = lambda arr, file_: xp.save(file_, arr)
serializers = OrderedDict((
('vectors', lambda p: save_array(self.data, p.open('wb'))),
('keys', lambda p: xp.save(p.open('wb'), self.keys))
('key2row', lambda p: msgpack.dump(self.key2row, p.open('wb')))
))
return util.to_disk(path, serializers, exclude)
@@ -225,12 +307,18 @@
path (unicode / Path): Directory path, string or Path-like object.
RETURNS (Vectors): The modified object.
"""
def load_key2row(path):
if path.exists():
self.key2row = msgpack.load(path.open('rb'))
for key, row in self.key2row.items():
if row in self._unset:
self._unset.remove(row)
def load_keys(path):
if path.exists():
self.keys = numpy.load(path2str(path))
for i, key in enumerate(self.keys):
self.keys[i] = key
self.key2row[key] = i
keys = numpy.load(str(path))
for i, key in enumerate(keys):
self.add(key, row=i)
def load_vectors(path):
xp = Model.ops.xp
@@ -238,6 +326,7 @@
self.data = xp.load(path)
serializers = OrderedDict((
('key2row', load_key2row),
('keys', load_keys),
('vectors', load_vectors),
))
@@ -256,7 +345,7 @@
else:
return msgpack.dumps(self.data)
serializers = OrderedDict((
('keys', lambda: msgpack.dumps(self.keys)),
('key2row', lambda: msgpack.dumps(self.key2row)),
('vectors', serialize_weights)
))
return util.to_bytes(serializers, exclude)
@@ -274,14 +363,8 @@
else:
self.data = msgpack.loads(b)
def load_keys(keys):
self.keys.resize((len(keys),))
for i, key in enumerate(keys):
self.keys[i] = key
self.key2row[key] = i
deserializers = OrderedDict((
('keys', lambda b: load_keys(msgpack.loads(b))),
('key2row', lambda b: self.key2row.update(msgpack.loads(b))),
('vectors', deserialize_weights)
))
util.from_bytes(data, deserializers, exclude)
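Serialization now round-trips key2row via msgpack instead of the old parallel keys array; only the from_disk path keeps a keys loader, for older saves. A hedged round-trip sketch, assuming the msgpack/numpy handling registered above:

    import numpy
    from spacy.vectors import Vectors

    v = Vectors(data=numpy.zeros((10, 4), dtype='f'))
    v.add(15, row=3)  # keys are normally uint64 hashes
    v2 = Vectors(shape=(10, 4))
    v2.from_bytes(v.to_bytes())
    assert list(v2.key2row.keys()) == [15]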

View File: spacy/vocab.pyx

@@ -55,7 +55,7 @@ cdef class Vocab:
_ = self[string]
self.lex_attr_getters = lex_attr_getters
self.morphology = Morphology(self.strings, tag_map, lemmatizer)
self.vectors = Vectors(self.strings, width=0)
self.vectors = Vectors()
property lang:
def __get__(self):
@@ -192,10 +192,11 @@
YIELDS (Lexeme): An entry in the vocabulary.
"""
cdef attr_t orth
cdef attr_t key
cdef size_t addr
for orth, addr in self._by_orth.items():
yield Lexeme(self, orth)
for key, addr in self._by_orth.items():
lex = Lexeme(self, key)
yield lex
def __getitem__(self, id_or_string):
"""Retrieve a lexeme, given an int ID or a unicode string. If a
@@ -213,7 +214,7 @@
>>> assert nlp.vocab[apple] == nlp.vocab[u'apple']
"""
cdef attr_t orth
if type(id_or_string) == unicode:
if isinstance(id_or_string, unicode):
orth = self.strings.add(id_or_string)
else:
orth = id_or_string
@@ -240,15 +241,19 @@
def vectors_length(self):
return self.vectors.data.shape[1]
def clear_vectors(self, width=None):
def reset_vectors(self, *, width=None, shape=None):
"""Drop the current vector table. Because all vectors must be the same
width, you have to call this to change the size of the vectors.
"""
if width is None:
width = self.vectors.data.shape[1]
self.vectors = Vectors(self.strings, width=width)
if width is not None and shape is not None:
raise ValueError("Only one of width and shape can be specified")
elif shape is not None:
self.vectors = Vectors(shape=shape)
else:
width = width if width is not None else self.vectors.data.shape[1]
self.vectors = Vectors(shape=(self.vectors.shape[0], width))
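# (Illustrative, not part of the diff:
#      vocab.reset_vectors(width=300)           # keep row count, new width
#      vocab.reset_vectors(shape=(20000, 300))  # preallocate whole table
#  passing both width and shape raises ValueError.)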
def prune_vectors(self, nr_row, batch_size=8):
def prune_vectors(self, nr_row, batch_size=1024):
"""Reduce the current vector table to `nr_row` unique entries. Words
mapped to the discarded vectors will be remapped to the closest vector
among those remaining.
@@ -274,36 +279,31 @@
two words.
"""
xp = get_array_module(self.vectors.data)
# Work in batches, to avoid memory problems.
keep = self.vectors.data[:nr_row]
keep_keys = [key for key, row in self.vectors.key2row.items() if row < nr_row]
toss = self.vectors.data[nr_row:]
# Normalize the vectors, so cosine similarity is just dot product.
# Note we can't modify the ones we're keeping in-place...
keep = keep / (xp.linalg.norm(keep, axis=1, keepdims=True)+1e-8)
keep = xp.ascontiguousarray(keep.T)
neighbours = xp.zeros((toss.shape[0],), dtype='i')
scores = xp.zeros((toss.shape[0],), dtype='f')
for i in range(0, toss.shape[0], batch_size):
batch = toss[i : i+batch_size]
batch /= xp.linalg.norm(batch, axis=1, keepdims=True)+1e-8
sims = xp.dot(batch, keep)
matches = sims.argmax(axis=1)
neighbours[i:i+batch_size] = matches
scores[i:i+batch_size] = sims.max(axis=1)
for lex in self:
# If we're losing the vector for this word, map it to the nearest
# vector we're keeping.
if lex.rank >= nr_row:
lex.rank = neighbours[lex.rank-nr_row]
self.vectors.add(lex.orth, row=lex.rank)
for key in self.vectors.keys:
row = self.vectors.key2row[key]
if row >= nr_row:
self.vectors.key2row[key] = neighbours[row-nr_row]
# Make copy, to encourage the original table to be garbage collected.
self.vectors.data = xp.ascontiguousarray(self.vectors.data[:nr_row])
# TODO: return new mapping
# Negate prob so an ascending sort puts higher-probability words first;
# ties break on the row (key2row contains the rank)
priority = [(-lex.prob, self.vectors.key2row[lex.orth], lex.orth)
for lex in self if lex.orth in self.vectors.key2row]
priority.sort()
indices = xp.asarray([i for (prob, i, key) in priority], dtype='i')
keys = xp.asarray([key for (prob, i, key) in priority], dtype='uint64')
keep = xp.ascontiguousarray(self.vectors.data[indices[:nr_row]])
toss = xp.ascontiguousarray(self.vectors.data[indices[nr_row:]])
self.vectors = Vectors(data=keep, keys=keys)
syn_keys, syn_rows, scores = self.vectors.most_similar(toss,
return_rows=True, return_scores=True)
remap = {}
for i, key in enumerate(keys[nr_row:]):
self.vectors.add(key, row=syn_rows[i])
word = self.strings[key]
synonym = self.strings[syn_keys[i]]
score = scores[i]
remap[word] = (synonym, score)
link_vectors_to_models(self)
return remap
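# (Illustrative, not part of the diff: remap = vocab.prune_vectors(10000)
#  gives e.g. remap[u'shoes'] == (u'shoe', 0.95): each pruned word maps to
#  its nearest surviving neighbour plus the cosine score; the values here
#  are hypothetical.)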
def get_vector(self, orth):
"""Retrieve a vector for a word in the vocabulary. Words can be looked
@@ -325,8 +325,16 @@
"""Set a vector for a word in the vocabulary. Words can be referenced
by string or int ID.
"""
if not isinstance(orth, basestring_):
orth = self.strings[orth]
if isinstance(orth, basestring_):
orth = self.strings.add(orth)
if self.vectors.is_full and orth not in self.vectors:
new_rows = max(100, int(self.vectors.shape[0]*1.3))
if self.vectors.shape[1] == 0:
width = vector.size
else:
width = self.vectors.shape[1]
self.vectors.resize((new_rows, width))
self.vectors.add(orth, vector=vector)
def has_vector(self, orth):