From 368fdb389ad23d07b99604483a8f96ff5a11e1d0 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Tue, 31 Oct 2017 02:00:26 +0100
Subject: [PATCH] WIP on refactoring and fixing vectors

---
 spacy/_ml.py                          | 16 ++++--
 spacy/cli/train.py                    |  5 ++
 spacy/tests/vocab/test_add_vectors.py | 44 ++++++++++++---
 spacy/vectors.pyx                     | 61 +++++++++------------
 spacy/vocab.pyx                       | 81 ++++++++++++++++++++++++---
 5 files changed, 149 insertions(+), 58 deletions(-)

diff --git a/spacy/_ml.py b/spacy/_ml.py
index c99f840b7..e9dac11df 100644
--- a/spacy/_ml.py
+++ b/spacy/_ml.py
@@ -29,6 +29,12 @@ from . import util
 VECTORS_KEY = 'spacy_pretrained_vectors'
 
 
+def cosine(vec1, vec2):
+    norm1 = (vec1**2).sum() ** 0.5
+    norm2 = (vec2**2).sum() ** 0.5
+    return vec1.dot(vec2) / (norm1 * norm2)
+
+
 @layerize
 def _flatten_add_lengths(seqs, pad=0, drop=0.):
     ops = Model.ops
@@ -198,11 +204,11 @@ class PrecomputableAffine(Model):
 def link_vectors_to_models(vocab):
     vectors = vocab.vectors
     ops = Model.ops
-    for word in vocab:
-        if word.orth in vectors.key2row:
-            word.rank = vectors.key2row[word.orth]
-        else:
-            word.rank = 0
+    #for word in vocab:
+    #    if word.orth in vectors.key2row:
+    #        word.rank = vectors.key2row[word.orth]
+    #    else:
+    #        word.rank = 0
     data = ops.asarray(vectors.data)
     # Set an entry here, so that vectors are accessed by StaticVectors
     # (unideal, I know)
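
[Editor's note, not part of the patch: a quick sanity check of the new `cosine` helper; the example vectors are invented. Note the helper has no epsilon guard, so a zero vector divides by zero; the `prune_vectors` method added in spacy/vocab.pyx below adds 1e-8 to the norms before dividing for exactly this reason.]

import numpy

from spacy._ml import cosine

a = numpy.asarray([1., 1., 1.], dtype='f')
b = numpy.asarray([1.1, 1.1, 1.1], dtype='f')
assert abs(cosine(a, b) - 1.0) < 1e-6   # parallel vectors: similarity ~1.0
assert abs(cosine(a, -a) + 1.0) < 1e-6  # opposite vectors: similarity ~-1.0
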
diff --git a/spacy/cli/train.py b/spacy/cli/train.py
index fb96e6c05..2300c3b94 100644
--- a/spacy/cli/train.py
+++ b/spacy/cli/train.py
@@ -94,6 +94,11 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
     nlp.meta.update(meta)
     if vectors:
         util.load_model(vectors, vocab=nlp.vocab)
+    if vectors_limit is not None:
+        remap = nlp.vocab.prune_vectors(vectors_limit)
+        print('remap', len(remap))
+        for key, (value, sim) in remap.items():
+            print(repr(key), repr(value), sim)
     for name in pipeline:
         nlp.add_pipe(nlp.create_pipe(name), name=name)
     optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
diff --git a/spacy/tests/vocab/test_add_vectors.py b/spacy/tests/vocab/test_add_vectors.py
index 10477cdf1..0ce95e5e9 100644
--- a/spacy/tests/vocab/test_add_vectors.py
+++ b/spacy/tests/vocab/test_add_vectors.py
@@ -3,13 +3,41 @@ from __future__ import unicode_literals
 
 import numpy
 import pytest
+from ...vocab import Vocab
+from ..._ml import cosine
 
 
-@pytest.mark.xfail
-@pytest.mark.parametrize('text', ["Hello"])
-def test_vocab_add_vector(en_vocab, text):
-    en_vocab.resize_vectors(10)
-    lex = en_vocab[text]
-    lex.vector = numpy.ndarray((10,), dtype='float32')
-    lex = en_vocab[text]
-    assert lex.vector.shape == (10,)
+def test_vocab_add_vector():
+    vocab = Vocab()
+    data = numpy.ndarray((5,3), dtype='f')
+    data[0] = 1.
+    data[1] = 2.
+    vocab.set_vector(u'cat', data[0])
+    vocab.set_vector(u'dog', data[1])
+    cat = vocab[u'cat']
+    assert list(cat.vector) == [1., 1., 1.]
+    dog = vocab[u'dog']
+    assert list(dog.vector) == [2., 2., 2.]
+    for lex in vocab:
+        print(lex.orth_)
+
+
+def test_vocab_prune_vectors():
+    vocab = Vocab()
+    _ = vocab[u'cat']
+    _ = vocab[u'dog']
+    _ = vocab[u'kitten']
+    print(list(vocab.strings))
+    data = numpy.ndarray((5,3), dtype='f')
+    data[0] = 1.
+    data[1] = 2.
+    data[2] = 1.1
+    vocab.set_vector(u'cat', data[0])
+    vocab.set_vector(u'dog', data[1])
+    vocab.set_vector(u'kitten', data[2])
+    for lex in vocab:
+        print(lex.orth_)
+
+    remap = vocab.prune_vectors(2)
+    assert remap == {u'kitten': (u'cat', cosine(data[0], data[2]))}
+    #print(remap)
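
[Editor's note, not part of the patch: a sketch of how the new pruning hook is meant to be used outside the CLI, mirroring the train.py change above. The model name is hypothetical, any pipeline with pretrained vectors would do, and it assumes the rest of this patch is applied.]

from __future__ import print_function
import spacy

nlp = spacy.load('en_core_web_md')      # hypothetical: any model with vectors
remap = nlp.vocab.prune_vectors(10000)  # keep only the first 10,000 rows
for word, (synonym, score) in remap.items():
    # Each discarded word now shares a row with its nearest kept neighbour.
    print(word, '->', synonym, score)
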
""" - for i, key in enumerate(self.keys): + for key, row in self.key2row.items(): string = self.strings[key] - yield string, self.data[i] + yield string, self.data[row] @property def shape(self): @@ -202,7 +200,7 @@ cdef class Vectors: save_array = lambda arr, file_: xp.save(file_, arr) serializers = OrderedDict(( ('vectors', lambda p: save_array(self.data, p.open('wb'))), - ('keys', lambda p: xp.save(p.open('wb'), self.keys)) + ('key2row', lambda p: msgpack.dump(self.key2row, p.open('wb'))) )) return util.to_disk(path, serializers, exclude) @@ -215,10 +213,7 @@ cdef class Vectors: """ def load_keys(path): if path.exists(): - self.keys = numpy.load(path2str(path)) - for i, key in enumerate(self.keys): - self.keys[i] = key - self.key2row[key] = i + self.key2row = msgpack.load(path.open('rb')) def load_vectors(path): xp = Model.ops.xp @@ -226,7 +221,7 @@ cdef class Vectors: self.data = xp.load(path) serializers = OrderedDict(( - ('keys', load_keys), + ('key2row', load_keys), ('vectors', load_vectors), )) util.from_disk(path, serializers, exclude) @@ -244,7 +239,7 @@ cdef class Vectors: else: return msgpack.dumps(self.data) serializers = OrderedDict(( - ('keys', lambda: msgpack.dumps(self.keys)), + ('key2row', lambda: msgpack.dumps(self.key2row)), ('vectors', serialize_weights) )) return util.to_bytes(serializers, exclude) @@ -262,14 +257,8 @@ cdef class Vectors: else: self.data = msgpack.loads(b) - def load_keys(keys): - self.keys.resize((len(keys),)) - for i, key in enumerate(keys): - self.keys[i] = key - self.key2row[key] = i - deserializers = OrderedDict(( - ('keys', lambda b: load_keys(msgpack.loads(b))), + ('key2row', lambda b: self.key2row.update(msgpack.loads(b))), ('vectors', deserialize_weights) )) util.from_bytes(data, deserializers, exclude) diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx index 8b09d7ee7..e3cad12e0 100644 --- a/spacy/vocab.pyx +++ b/spacy/vocab.pyx @@ -190,10 +190,11 @@ cdef class Vocab: YIELDS (Lexeme): An entry in the vocabulary. """ - cdef attr_t orth + cdef attr_t key cdef size_t addr - for orth, addr in self._by_orth.items(): - yield Lexeme(self, orth) + for key, addr in self._by_orth.items(): + lex = Lexeme(self, key) + yield lex def __getitem__(self, id_or_string): """Retrieve a lexeme, given an int ID or a unicode string. If a @@ -211,7 +212,7 @@ cdef class Vocab: >>> assert nlp.vocab[apple] == nlp.vocab[u'apple'] """ cdef attr_t orth - if type(id_or_string) == unicode: + if isinstance(id_or_string, unicode): orth = self.strings.add(id_or_string) else: orth = id_or_string @@ -242,9 +243,69 @@ cdef class Vocab: """Drop the current vector table. Because all vectors must be the same width, you have to call this to change the size of the vectors. """ - if new_dim is None: - new_dim = self.vectors.data.shape[1] - self.vectors = Vectors(self.strings, width=new_dim) + if width is None: + width = self.vectors.data.shape[1] + self.vectors = Vectors(self.strings, width=width) + + def prune_vectors(self, nr_row, batch_size=8): + """Reduce the current vector table to `nr_row` unique entries. Words + mapped to the discarded vectors will be remapped to the closest vector + among those remaining. + + For example, suppose the original table had vectors for the words: + ['sat', 'cat', 'feline', 'reclined']. If we prune the vector table to, + two rows, we would discard the vectors for 'feline' and 'reclined'. 
diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx
index 8b09d7ee7..e3cad12e0 100644
--- a/spacy/vocab.pyx
+++ b/spacy/vocab.pyx
@@ -190,10 +190,11 @@
 
         YIELDS (Lexeme): An entry in the vocabulary.
         """
-        cdef attr_t orth
+        cdef attr_t key
         cdef size_t addr
-        for orth, addr in self._by_orth.items():
-            yield Lexeme(self, orth)
+        for key, addr in self._by_orth.items():
+            lex = Lexeme(self, key)
+            yield lex
 
     def __getitem__(self, id_or_string):
         """Retrieve a lexeme, given an int ID or a unicode string. If a
@@ -211,7 +212,7 @@
         >>> assert nlp.vocab[apple] == nlp.vocab[u'apple']
         """
         cdef attr_t orth
-        if type(id_or_string) == unicode:
+        if isinstance(id_or_string, unicode):
             orth = self.strings.add(id_or_string)
         else:
             orth = id_or_string
@@ -242,9 +243,69 @@
         """Drop the current vector table. Because all vectors must be the
        same width, you have to call this to change the size of the vectors.
        """
-        if new_dim is None:
-            new_dim = self.vectors.data.shape[1]
-        self.vectors = Vectors(self.strings, width=new_dim)
+        if width is None:
+            width = self.vectors.data.shape[1]
+        self.vectors = Vectors(self.strings, width=width)
+
+    def prune_vectors(self, nr_row, batch_size=8):
+        """Reduce the current vector table to `nr_row` unique entries. Words
+        mapped to the discarded vectors will be remapped to the closest vector
+        among those remaining.
+
+        For example, suppose the original table had vectors for the words:
+        ['sat', 'cat', 'feline', 'reclined']. If we prune the vector table to
+        two rows, we would discard the vectors for 'feline' and 'reclined'.
+        These words would then be remapped to the closest remaining vector
+        -- so "feline" would have the same vector as "cat", and "reclined"
+        would have the same vector as "sat".
+
+        The similarities are judged by cosine. The original vectors may
+        be large, so the cosines are calculated in minibatches, to reduce
+        memory usage.
+
+        nr_row (int): The number of rows to keep in the vector table.
+        batch_size (int): Batch of vectors for calculating the similarities.
+            Larger batch sizes might be faster, while temporarily requiring
+            more memory.
+        RETURNS (dict): A dictionary keyed by removed words mapped to
+            `(string, score)` tuples, where `string` is the entry the removed
+            word was mapped to, and `score` the similarity score between the
+            two words.
+        """
+        xp = get_array_module(self.vectors.data)
+        # Work in batches, to avoid memory problems.
+        keep = self.vectors.data[:nr_row]
+        toss = self.vectors.data[nr_row:]
+        # Normalize the vectors, so cosine similarity is just dot product.
+        # Note we can't modify the ones we're keeping in-place...
+        keep = keep / (xp.linalg.norm(keep, axis=1, keepdims=True)+1e-8)
+        keep = xp.ascontiguousarray(keep.T)
+        neighbours = xp.zeros((toss.shape[0],), dtype='i')
+        scores = xp.zeros((toss.shape[0],), dtype='f')
+        for i in range(0, toss.shape[0], batch_size):
+            batch = toss[i : i+batch_size]
+            batch /= xp.linalg.norm(batch, axis=1, keepdims=True)+1e-8
+            sims = xp.dot(batch, keep)
+            matches = sims.argmax(axis=1)
+            neighbours[i:i+batch_size] = matches
+            scores[i:i+batch_size] = sims.max(axis=1)
+        i2k = {i: key for key, i in self.vectors.key2row.items()}
+        remap = {}
+        for lex in list(self):
+            # If we're losing the vector for this word, map it to the nearest
+            # vector we're keeping.
+            if lex.rank >= nr_row:
+                offset = lex.rank - nr_row
+                lex.rank = neighbours[offset]
+                self.vectors.add(lex.orth, row=lex.rank)
+                remap[lex.orth_] = (self.strings[i2k[lex.rank]], scores[offset])
+        for key, row in self.vectors.key2row.items():
+            if row >= nr_row:
+                self.vectors.key2row[key] = neighbours[row-nr_row]
+        # Make a copy, to encourage the original table to be garbage collected.
+        self.vectors.data = xp.ascontiguousarray(self.vectors.data[:nr_row])
+        link_vectors_to_models(self)
+        return remap
 
     def get_vector(self, orth):
         """Retrieve a vector for a word in the vocabulary. Words can be looked
@@ -266,9 +327,11 @@
         """Set a vector for a word in the vocabulary. Words can be referenced
         by string or int ID.
         """
-        if not isinstance(orth, basestring_):
-            orth = self.strings[orth]
+        if self.vectors.data.size == 0:
+            self.clear_vectors(vector.shape[0])
+        lex = self[orth]
         self.vectors.add(orth, vector=vector)
+        lex.rank = self.vectors.key2row[lex.orth]
 
     def has_vector(self, orth):
         """Check whether a word has a vector. Returns False if no vectors have
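
[Editor's note, not part of the patch: a standalone NumPy sketch of the batched nearest-neighbour search that Vocab.prune_vectors performs. The function name and data are invented; the real method additionally rewrites key2row, lexeme ranks, and the vector table itself.]

import numpy

def nearest_kept_rows(data, nr_row, batch_size=8):
    keep = data[:nr_row]
    toss = data[nr_row:]
    # Normalize so cosine similarity reduces to a dot product; the 1e-8
    # guards against zero vectors, as in the patch.
    keep = keep / (numpy.linalg.norm(keep, axis=1, keepdims=True) + 1e-8)
    keep = numpy.ascontiguousarray(keep.T)
    neighbours = numpy.zeros((toss.shape[0],), dtype='i')
    scores = numpy.zeros((toss.shape[0],), dtype='f')
    # Walk the discarded rows in minibatches to bound peak memory use.
    for i in range(0, toss.shape[0], batch_size):
        batch = toss[i : i + batch_size]
        batch = batch / (numpy.linalg.norm(batch, axis=1, keepdims=True) + 1e-8)
        sims = batch.dot(keep)  # (batch, nr_row) cosine similarities
        neighbours[i : i + batch_size] = sims.argmax(axis=1)
        scores[i : i + batch_size] = sims.max(axis=1)
    return neighbours, scores

# Rows 0-1 are kept; row 2 plays the role of 'kitten' in the test above.
data = numpy.asarray([[1., 1., 1.], [2., 0., 0.], [1.1, 1.1, 1.1]], dtype='f')
neighbours, scores = nearest_kept_rows(data, 2)
assert neighbours[0] == 0  # the tossed row maps back to row 0 ('cat')
assert scores[0] > 0.99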