WIP on refactoring and fixing vectors

Matthew Honnibal 2017-10-31 02:00:26 +01:00
parent 256c7dac5a
commit 368fdb389a
5 changed files with 147 additions and 56 deletions

View File

@@ -29,6 +29,12 @@ from . import util
 VECTORS_KEY = 'spacy_pretrained_vectors'
 
 
+def cosine(vec1, vec2):
+    norm1 = (vec1**2).sum() ** 0.5
+    norm2 = (vec2**2).sum() ** 0.5
+    return vec1.dot(vec2) / (norm1 * norm2)
+
+
 @layerize
 def _flatten_add_lengths(seqs, pad=0, drop=0.):
     ops = Model.ops
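A quick sketch of the new cosine helper in use (plain numpy; not part of the commit):

    import numpy
    vec1 = numpy.asarray([1., 0., 0.], dtype='f')
    vec2 = numpy.asarray([1., 1., 0.], dtype='f')
    # Dot product divided by the product of the norms: ~0.7071 here.
    print(cosine(vec1, vec2))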
@@ -198,11 +204,11 @@ class PrecomputableAffine(Model):
 def link_vectors_to_models(vocab):
     vectors = vocab.vectors
     ops = Model.ops
-    for word in vocab:
-        if word.orth in vectors.key2row:
-            word.rank = vectors.key2row[word.orth]
-        else:
-            word.rank = 0
+    #for word in vocab:
+    #    if word.orth in vectors.key2row:
+    #        word.rank = vectors.key2row[word.orth]
+    #    else:
+    #        word.rank = 0
     data = ops.asarray(vectors.data)
     # Set an entry here, so that vectors are accessed by StaticVectors
     # (unideal, I know)

View File

@@ -94,6 +94,11 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
     nlp.meta.update(meta)
     if vectors:
         util.load_model(vectors, vocab=nlp.vocab)
+        if vectors_limit is not None:
+            remap = nlp.vocab.prune_vectors(vectors_limit)
+            print('remap', len(remap))
+            for key, (value, sim) in remap.items():
+                print(repr(key), repr(value), sim)
     for name in pipeline:
         nlp.add_pipe(nlp.create_pipe(name), name=name)
     optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
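The same pruning can be exercised outside the CLI. A minimal sketch, assuming a loaded pipeline whose vocab has a vector table (the model name here is hypothetical):

    import spacy
    nlp = spacy.load('en_core_web_md')      # hypothetical model with vectors
    remap = nlp.vocab.prune_vectors(10000)  # keep the first 10000 rows
    for key, (value, sim) in remap.items():
        print(repr(key), repr(value), sim)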

View File

@@ -3,13 +3,41 @@ from __future__ import unicode_literals
 import numpy
 import pytest
 
+from ...vocab import Vocab
+from ..._ml import cosine
+
-@pytest.mark.xfail
-@pytest.mark.parametrize('text', ["Hello"])
-def test_vocab_add_vector(en_vocab, text):
-    en_vocab.resize_vectors(10)
-    lex = en_vocab[text]
-    lex.vector = numpy.ndarray((10,), dtype='float32')
-    lex = en_vocab[text]
-    assert lex.vector.shape == (10,)
+
+def test_vocab_add_vector():
+    vocab = Vocab()
+    data = numpy.ndarray((5,3), dtype='f')
+    data[0] = 1.
+    data[1] = 2.
+    vocab.set_vector(u'cat', data[0])
+    vocab.set_vector(u'dog', data[1])
+    cat = vocab[u'cat']
+    assert list(cat.vector) == [1., 1., 1.]
+    dog = vocab[u'dog']
+    assert list(dog.vector) == [2., 2., 2.]
+    for lex in vocab:
+        print(lex.orth_)
+
+
+def test_vocab_prune_vectors():
+    vocab = Vocab()
+    _ = vocab[u'cat']
+    _ = vocab[u'dog']
+    _ = vocab[u'kitten']
+    print(list(vocab.strings))
+    data = numpy.ndarray((5,3), dtype='f')
+    data[0] = 1.
+    data[1] = 2.
+    data[2] = 1.1
+    vocab.set_vector(u'cat', data[0])
+    vocab.set_vector(u'dog', data[1])
+    vocab.set_vector(u'kitten', data[2])
+    for lex in vocab:
+        print(lex.orth_)
+    remap = vocab.prune_vectors(2)
+    assert remap == {u'kitten': (u'cat', cosine(data[0], data[2]))}
+    #print(remap)
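To run just these tests with the print output visible (path assumed from the relative imports):

    pytest spacy/tests/vocab/test_add_vectors.py -s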

View File

@@ -27,8 +27,7 @@ cdef class Vectors:
     cdef public object data
     cdef readonly StringStore strings
     cdef public object key2row
-    cdef public object keys
-    cdef public int i
+    cdef public int _i_vec
 
     def __init__(self, strings, width=0, data=None):
         """Create a new vector store. To keep the vector table empty, pass
@@ -51,13 +50,13 @@ cdef class Vectors:
             self.data = numpy.asarray(data, dtype='f')
         else:
             self.data = numpy.zeros((len(self.strings), width), dtype='f')
-        self.i = 0
+        self._i_vec = 0
         self.key2row = {}
-        self.keys = numpy.zeros((self.data.shape[0],), dtype='uint64')
-        for i, string in enumerate(self.strings):
-            if i >= self.data.shape[0]:
-                break
-            self.add(self.strings[string], self.data[i])
+        if data is not None:
+            for i, string in enumerate(self.strings):
+                if i >= self.data.shape[0]:
+                    break
+                self.add(self.strings[string], vector=self.data[i])
 
     def __reduce__(self):
         return (Vectors, (self.strings, self.data))
@@ -122,16 +121,15 @@ cdef class Vectors:
         """
         if isinstance(key, basestring_):
             key = self.strings.add(key)
-        if key not in self.key2row:
-            i = self.i
-            if i >= self.keys.shape[0]:
-                self.keys.resize((self.keys.shape[0]*2,))
-                self.data.resize((self.data.shape[0]*2, self.data.shape[1]))
-            self.key2row[key] = self.i
-            self.keys[self.i] = key
-            self.i += 1
-        else:
-            i = self.key2row[key]
+        if row is None and key in self.key2row:
+            row = self.key2row[key]
+        elif row is None:
+            row = self._i_vec
+            self._i_vec += 1
+            if row >= self.data.shape[0]:
+                self.data.resize((row*2, self.data.shape[1]))
+
+        self.key2row[key] = row
         if vector is not None:
-            self.data[i] = vector
-        return i
+            self.data[row] = vector
+        return row
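The reworked add() lets a caller pin a key to an existing row, which is what prune_vectors relies on below. A sketch under the signatures in this diff (the preallocated table is an assumption, to sidestep resizing):

    import numpy
    from spacy.strings import StringStore
    from spacy.vectors import Vectors

    data = numpy.zeros((2, 3), dtype='f')
    vectors = Vectors(StringStore(), data=data)
    row = vectors.add(u'cat', vector=numpy.asarray([1., 2., 3.], dtype='f'))
    # Alias u'kitten' onto the same row: no new vector is stored.
    assert vectors.add(u'kitten', row=row) == row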
@@ -141,9 +139,9 @@ cdef class Vectors:
 
         YIELDS (tuple): A key/vector pair.
         """
-        for i, key in enumerate(self.keys):
+        for key, row in self.key2row.items():
             string = self.strings[key]
-            yield string, self.data[i]
+            yield string, self.data[row]
 
     @property
     def shape(self):
@@ -202,7 +200,7 @@ cdef class Vectors:
         save_array = lambda arr, file_: xp.save(file_, arr)
         serializers = OrderedDict((
             ('vectors', lambda p: save_array(self.data, p.open('wb'))),
-            ('keys', lambda p: xp.save(p.open('wb'), self.keys))
+            ('key2row', lambda p: msgpack.dump(self.key2row, p.open('wb')))
         ))
         return util.to_disk(path, serializers, exclude)
@@ -215,10 +213,7 @@ cdef class Vectors:
         """
         def load_keys(path):
             if path.exists():
-                self.keys = numpy.load(path2str(path))
-                for i, key in enumerate(self.keys):
-                    self.keys[i] = key
-                    self.key2row[key] = i
+                self.key2row = msgpack.load(path.open('rb'))
 
         def load_vectors(path):
             xp = Model.ops.xp
@@ -226,7 +221,7 @@ cdef class Vectors:
                 self.data = xp.load(path)
 
         serializers = OrderedDict((
-            ('keys', load_keys),
+            ('key2row', load_keys),
             ('vectors', load_vectors),
         ))
         util.from_disk(path, serializers, exclude)
@@ -244,7 +239,7 @@ cdef class Vectors:
         else:
             return msgpack.dumps(self.data)
         serializers = OrderedDict((
-            ('keys', lambda: msgpack.dumps(self.keys)),
+            ('key2row', lambda: msgpack.dumps(self.key2row)),
             ('vectors', serialize_weights)
         ))
         return util.to_bytes(serializers, exclude)
@@ -262,14 +257,8 @@ cdef class Vectors:
         else:
             self.data = msgpack.loads(b)
 
-        def load_keys(keys):
-            self.keys.resize((len(keys),))
-            for i, key in enumerate(keys):
-                self.keys[i] = key
-                self.key2row[key] = i
-
         deserializers = OrderedDict((
-            ('keys', lambda b: load_keys(msgpack.loads(b))),
+            ('key2row', lambda b: self.key2row.update(msgpack.loads(b))),
             ('vectors', deserialize_weights)
         ))
         util.from_bytes(data, deserializers, exclude)
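With the keys array gone, serialization round-trips just the key2row dict and the data table. Continuing the sketch above (assumes the msgpack numpy hooks that spaCy configures elsewhere handle the array payload):

    data_bytes = vectors.to_bytes()
    vectors2 = Vectors(StringStore(), width=3)
    vectors2.from_bytes(data_bytes)
    assert vectors2.key2row == vectors.key2row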

View File

@@ -190,10 +190,11 @@ cdef class Vocab:
 
         YIELDS (Lexeme): An entry in the vocabulary.
         """
-        cdef attr_t orth
+        cdef attr_t key
         cdef size_t addr
-        for orth, addr in self._by_orth.items():
-            yield Lexeme(self, orth)
+        for key, addr in self._by_orth.items():
+            lex = Lexeme(self, key)
+            yield lex
 
     def __getitem__(self, id_or_string):
         """Retrieve a lexeme, given an int ID or a unicode string. If a
@@ -211,7 +212,7 @@ cdef class Vocab:
             >>> assert nlp.vocab[apple] == nlp.vocab[u'apple']
         """
         cdef attr_t orth
-        if type(id_or_string) == unicode:
+        if isinstance(id_or_string, unicode):
             orth = self.strings.add(id_or_string)
         else:
             orth = id_or_string
@@ -242,9 +243,69 @@ cdef class Vocab:
         """Drop the current vector table. Because all vectors must be the same
         width, you have to call this to change the size of the vectors.
         """
-        if new_dim is None:
-            new_dim = self.vectors.data.shape[1]
-        self.vectors = Vectors(self.strings, width=new_dim)
+        if width is None:
+            width = self.vectors.data.shape[1]
+        self.vectors = Vectors(self.strings, width=width)
+
+    def prune_vectors(self, nr_row, batch_size=8):
+        """Reduce the current vector table to `nr_row` unique entries. Words
+        mapped to the discarded vectors will be remapped to the closest vector
+        among those remaining.
+
+        For example, suppose the original table had vectors for the words:
+        ['sat', 'cat', 'feline', 'reclined']. If we prune the vector table to
+        two rows, we would discard the vectors for 'feline' and 'reclined'.
+        These words would then be remapped to the closest remaining vector
+        -- so "feline" would have the same vector as "cat", and "reclined"
+        would have the same vector as "sat".
+
+        The similarities are judged by cosine. The original vectors may
+        be large, so the cosines are calculated in minibatches, to reduce
+        memory usage.
+
+        nr_row (int): The number of rows to keep in the vector table.
+        batch_size (int): Batch of vectors for calculating the similarities.
+            Larger batch sizes might be faster, while temporarily requiring
+            more memory.
+        RETURNS (dict): A dictionary keyed by removed words, mapped to
+            `(string, score)` tuples, where `string` is the entry the removed
+            word was mapped to, and `score` the similarity score between the
+            two words.
+        """
+        xp = get_array_module(self.vectors.data)
+        # Work in batches, to avoid memory problems.
+        keep = self.vectors.data[:nr_row]
+        keep_keys = [key for key, row in self.vectors.key2row.items() if row < nr_row]
+        toss = self.vectors.data[nr_row:]
+        # Normalize the vectors, so cosine similarity is just dot product.
+        # Note we can't modify the ones we're keeping in-place...
+        keep = keep / (xp.linalg.norm(keep, axis=1, keepdims=True)+1e-8)
+        keep = xp.ascontiguousarray(keep.T)
+        neighbours = xp.zeros((toss.shape[0],), dtype='i')
+        scores = xp.zeros((toss.shape[0],), dtype='f')
+        for i in range(0, toss.shape[0], batch_size):
+            batch = toss[i : i+batch_size]
+            batch /= xp.linalg.norm(batch, axis=1, keepdims=True)+1e-8
+            sims = xp.dot(batch, keep)
+            matches = sims.argmax(axis=1)
+            neighbours[i:i+batch_size] = matches
+            scores[i:i+batch_size] = sims.max(axis=1)
+        i2k = {i: key for key, i in self.vectors.key2row.items()}
+        remap = {}
+        for lex in list(self):
+            # If we're losing the vector for this word, map it to the nearest
+            # vector we're keeping.
+            if lex.rank >= nr_row:
+                offset = lex.rank - nr_row
+                lex.rank = neighbours[offset]
+                self.vectors.add(lex.orth, row=lex.rank)
+                remap[lex.orth_] = (i2k[lex.rank], scores[offset])
+        for key, row in self.vectors.key2row.items():
+            if row >= nr_row:
+                self.vectors.key2row[key] = neighbours[row-nr_row]
+        # Make a copy, to encourage the original table to be garbage collected.
+        self.vectors.data = xp.ascontiguousarray(self.vectors.data[:nr_row])
+        link_vectors_to_models(self)
+        return remap
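A worked version of the prune_vectors docstring example, as a sketch of the intended behaviour (toy vectors invented for illustration):

    import numpy
    from spacy.vocab import Vocab

    vocab = Vocab()
    vocab.set_vector(u'sat', numpy.asarray([1., 0., 0.], dtype='f'))
    vocab.set_vector(u'cat', numpy.asarray([0., 1., 0.], dtype='f'))
    vocab.set_vector(u'feline', numpy.asarray([0., 1., 0.1], dtype='f'))
    vocab.set_vector(u'reclined', numpy.asarray([1., 0.1, 0.], dtype='f'))
    remap = vocab.prune_vectors(2)
    # Expected shape of the result, roughly:
    # {u'feline': (u'cat', 0.99...), u'reclined': (u'sat', 0.99...)}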
 
     def get_vector(self, orth):
         """Retrieve a vector for a word in the vocabulary. Words can be looked
 
@@ -266,9 +327,11 @@ cdef class Vocab:
         """Set a vector for a word in the vocabulary. Words can be referenced
         by string or int ID.
         """
-        if not isinstance(orth, basestring_):
-            orth = self.strings[orth]
+        if self.vectors.data.size == 0:
+            self.clear_vectors(vector.shape[0])
+        lex = self[orth]
         self.vectors.add(orth, vector=vector)
+        lex.rank = self.vectors.key2row[lex.orth]
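With the new guard, the first set_vector call on an empty table sizes it from the incoming vector, and the lexeme's rank stays in sync with its row. A sketch mirroring test_vocab_add_vector above:

    import numpy
    from spacy.vocab import Vocab

    vocab = Vocab()
    vocab.set_vector(u'dog', numpy.asarray([.1, .2, .3], dtype='f'))
    lex = vocab[u'dog']
    assert lex.rank == vocab.vectors.key2row[lex.orth]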
 
     def has_vector(self, orth):
         """Check whether a word has a vector. Returns False if no vectors have