diff --git a/spacy/_ml.py b/spacy/_ml.py
index c99f840b7..6bfacb20a 100644
--- a/spacy/_ml.py
+++ b/spacy/_ml.py
@@ -29,6 +29,16 @@ from . import util
 VECTORS_KEY = 'spacy_pretrained_vectors'
 
 
+def cosine(vec1, vec2):
+    xp = get_array_module(vec1)
+    norm1 = xp.linalg.norm(vec1)
+    norm2 = xp.linalg.norm(vec2)
+    if norm1 == 0. or norm2 == 0.:
+        return 0
+    else:
+        return vec1.dot(vec2) / (norm1 * norm2)
+
+
 @layerize
 def _flatten_add_lengths(seqs, pad=0, drop=0.):
     ops = Model.ops
diff --git a/spacy/cli/train.py b/spacy/cli/train.py
index 34117db22..74e1d6d68 100644
--- a/spacy/cli/train.py
+++ b/spacy/cli/train.py
@@ -32,7 +32,6 @@ numpy.random.seed(0)
     n_sents=("number of sentences", "option", "ns", int),
     use_gpu=("Use GPU", "option", "g", int),
     vectors=("Model to load vectors from", "option", "v"),
-    vectors_limit=("Truncate to N vectors (requires -v)", "option", None, int),
     no_tagger=("Don't train tagger", "flag", "T", bool),
    no_parser=("Don't train parser", "flag", "P", bool),
     no_entities=("Don't train NER", "flag", "N", bool),
@@ -41,7 +40,7 @@ numpy.random.seed(0)
     meta_path=("Optional path to meta.json. All relevant properties will be "
                "overwritten.", "option", "m", Path))
 def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
-          use_gpu=-1, vectors=None, vectors_limit=None, no_tagger=False,
+          use_gpu=-1, vectors=None, no_tagger=False,
           no_parser=False, no_entities=False, gold_preproc=False,
           version="0.0.0", meta_path=None):
     """
@@ -95,8 +94,6 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
         nlp.meta.update(meta)
     if vectors:
         util.load_model(vectors, vocab=nlp.vocab)
-        if vectors_limit is not None:
-            nlp.vocab.prune_vectors(vectors_limit)
     for name in pipeline:
         nlp.add_pipe(nlp.create_pipe(name), name=name)
     optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
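Aside: the new `cosine` helper in _ml.py is what the pruning code further down
reports similarities with, and `get_array_module` simply resolves to numpy for
CPU arrays (or cupy for GPU arrays). A minimal sketch of its behaviour,
assuming spaCy is installed:

    import numpy
    from spacy._ml import cosine

    vec1 = numpy.asarray([0., 2.], dtype='f')
    vec2 = numpy.asarray([0., 1.], dtype='f')
    # Parallel vectors score 1.0; a zero-norm vector short-circuits to 0.
    assert abs(cosine(vec1, vec2) - 1.0) < 1e-6
    assert cosine(vec1, numpy.zeros((2,), dtype='f')) == 0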
diff --git a/spacy/cli/vocab.py b/spacy/cli/vocab.py
index d05eff3f0..5f6f58d80 100644
--- a/spacy/cli/vocab.py
+++ b/spacy/cli/vocab.py
@@ -7,6 +7,7 @@ import spacy
 import numpy
 from pathlib import Path
 
+from ..vectors import Vectors
 from ..util import prints, ensure_path
 
 
@@ -16,8 +17,12 @@ from ..util import prints, ensure_path
     lexemes_loc=("location of JSONL-formatted lexical data", "positional",
                  None, Path),
     vectors_loc=("optional: location of vectors data, as numpy .npz",
-                 "positional", None, str))
-def make_vocab(cmd, lang, output_dir, lexemes_loc, vectors_loc=None):
+                 "positional", None, str),
+    prune_vectors=("optional: number of vectors to prune to.",
+                   "option", "V", int)
+)
+def make_vocab(cmd, lang, output_dir, lexemes_loc,
+               vectors_loc=None, prune_vectors=-1):
     """Compile a vocabulary from a lexicon jsonl file and word vectors."""
     if not lexemes_loc.exists():
         prints(lexemes_loc, title="Can't find lexical data", exits=1)
@@ -26,7 +31,6 @@ def make_vocab(cmd, lang, output_dir, lexemes_loc, vectors_loc=None):
     for word in nlp.vocab:
         word.rank = 0
     lex_added = 0
-    vec_added = 0
     with lexemes_loc.open() as file_:
         for line in file_:
             if line.strip():
@@ -39,16 +43,18 @@ def make_vocab(cmd, lang, output_dir, lexemes_loc, vectors_loc=None):
             assert lex.rank == attrs['id']
         lex_added += 1
     if vectors_loc is not None:
-        vector_data = numpy.load(open(vectors_loc, 'rb'))
-        nlp.vocab.clear_vectors(width=vector_data.shape[1])
+        vector_data = numpy.load(vectors_loc.open('rb'))
+        nlp.vocab.vectors = Vectors(data=vector_data)
         for word in nlp.vocab:
             if word.rank:
-                nlp.vocab.vectors.add(word.orth_, row=word.rank,
-                                      vector=vector_data[word.rank])
-                vec_added += 1
+                nlp.vocab.vectors.add(word.orth, row=word.rank)
+
+    if prune_vectors >= 1:
+        remap = nlp.vocab.prune_vectors(prune_vectors)
     if not output_dir.exists():
         output_dir.mkdir()
     nlp.to_disk(output_dir)
+    vec_added = len(nlp.vocab.vectors)
     prints("{} entries, {} vectors".format(lex_added, vec_added), output_dir,
            title="Sucessfully compiled vocab and vectors, and saved model")
     return nlp
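The CLI change above swaps row-by-row insertion for wholesale assignment of
the loaded array, with keys mapped onto rows by lexeme rank. A hedged sketch
of the equivalent API calls, with `vector_data` standing in for the real .npz
contents:

    import numpy
    from spacy.vectors import Vectors
    from spacy.vocab import Vocab

    vocab = Vocab()
    vector_data = numpy.zeros((3, 300), dtype='f')  # placeholder table
    vocab.vectors = Vectors(data=vector_data)
    # Keys are word hashes; map one onto the row given by its rank.
    vocab.vectors.add(vocab.strings.add(u'the'), row=0)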
Format: [("text", [1, 2, 3])]""" length = len(vectors[0][1]) - vocab.clear_vectors(length) + vocab.reset_vectors(width=length) for word, vec in vectors: - vocab.set_vector(word, vec) + vocab.set_vector(word, vector=vec) return vocab diff --git a/spacy/tests/vectors/test_vectors.py b/spacy/tests/vectors/test_vectors.py index 74ac26a10..ce183f9fd 100644 --- a/spacy/tests/vectors/test_vectors.py +++ b/spacy/tests/vectors/test_vectors.py @@ -35,20 +35,18 @@ def vocab(en_vocab, vectors): def test_init_vectors_with_data(strings, data): - v = Vectors(strings, data=data) + v = Vectors(data=data) assert v.shape == data.shape -def test_init_vectors_with_width(strings): - v = Vectors(strings, width=3) - for string in strings: - v.add(string) +def test_init_vectors_with_shape(strings): + v = Vectors(shape=(len(strings), 3)) assert v.shape == (len(strings), 3) def test_get_vector(strings, data): - v = Vectors(strings, data=data) - for string in strings: - v.add(string) + v = Vectors(data=data) + for i, string in enumerate(strings): + v.add(string, row=i) assert list(v[strings[0]]) == list(data[0]) assert list(v[strings[0]]) != list(data[1]) assert list(v[strings[1]]) != list(data[0]) @@ -56,9 +54,9 @@ def test_get_vector(strings, data): def test_set_vector(strings, data): orig = data.copy() - v = Vectors(strings, data=data) - for string in strings: - v.add(string) + v = Vectors(data=data) + for i, string in enumerate(strings): + v.add(string, row=i) assert list(v[strings[0]]) == list(orig[0]) assert list(v[strings[0]]) != list(orig[1]) v[strings[0]] = data[1] @@ -66,7 +64,6 @@ def test_set_vector(strings, data): assert list(v[strings[0]]) != list(orig[0]) - @pytest.fixture() def tokenizer_v(vocab): return Tokenizer(vocab, {}, None, None, None) diff --git a/spacy/tests/vocab/test_add_vectors.py b/spacy/tests/vocab/test_add_vectors.py index 10477cdf1..3ef599678 100644 --- a/spacy/tests/vocab/test_add_vectors.py +++ b/spacy/tests/vocab/test_add_vectors.py @@ -2,14 +2,39 @@ from __future__ import unicode_literals import numpy -import pytest +from numpy.testing import assert_allclose +from ...vocab import Vocab +from ..._ml import cosine -@pytest.mark.xfail -@pytest.mark.parametrize('text', ["Hello"]) -def test_vocab_add_vector(en_vocab, text): - en_vocab.resize_vectors(10) - lex = en_vocab[text] - lex.vector = numpy.ndarray((10,), dtype='float32') - lex = en_vocab[text] - assert lex.vector.shape == (10,) +def test_vocab_add_vector(): + vocab = Vocab() + data = numpy.ndarray((5,3), dtype='f') + data[0] = 1. + data[1] = 2. + vocab.set_vector(u'cat', data[0]) + vocab.set_vector(u'dog', data[1]) + cat = vocab[u'cat'] + assert list(cat.vector) == [1., 1., 1.] + dog = vocab[u'dog'] + assert list(dog.vector) == [2., 2., 2.] + + +def test_vocab_prune_vectors(): + vocab = Vocab() + _ = vocab[u'cat'] + _ = vocab[u'dog'] + _ = vocab[u'kitten'] + data = numpy.ndarray((5,3), dtype='f') + data[0] = 1. + data[1] = 2. + data[2] = 1.1 + vocab.set_vector(u'cat', data[0]) + vocab.set_vector(u'dog', data[1]) + vocab.set_vector(u'kitten', data[2]) + + remap = vocab.prune_vectors(2) + assert list(remap.keys()) == [u'kitten'] + neighbour, similarity = list(remap.values())[0] + assert neighbour == u'cat', remap + assert_allclose(similarity, cosine(data[0], data[2]), atol=1e-6) diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx index 552a6bcf3..08ab586d1 100644 --- a/spacy/vectors.pyx +++ b/spacy/vectors.pyx @@ -15,6 +15,12 @@ from .compat import basestring_, path2str from . 
diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx
index 552a6bcf3..08ab586d1 100644
--- a/spacy/vectors.pyx
+++ b/spacy/vectors.pyx
@@ -15,6 +15,13 @@ from .compat import basestring_, path2str
 from . import util
 
 
+def unpickle_vectors(keys_and_rows, data):
+    vectors = Vectors(data=data)
+    for key, row in keys_and_rows:
+        vectors.add(key, row=row)
+    return vectors
+
+
 cdef class Vectors:
     """Store, save and load word vectors.
 
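Together with the `__reduce__` implementation further down, this helper makes
Vectors objects picklable by replaying the key-to-row mapping onto the stored
array. A round-trip sketch:

    import pickle
    import numpy
    from spacy.vectors import Vectors

    v = Vectors(data=numpy.asarray([[0., 2.]], dtype='f'), keys=[123])
    v2 = pickle.loads(pickle.dumps(v))
    assert list(v2[123]) == [0., 2.]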
- """ - return self._i_vec - - def __contains__(self, key): - """Check whether a key has a vector entry in the table. - - key (unicode / int): The key to check. - RETURNS (bool): Whether the key has a vector entry. - """ - if isinstance(key, basestring_): - key = self.strings[key] - return key in self.key2row - - def add(self, key, *, vector=None, row=None): - """Add a key to the table. Keys can be mapped to an existing vector - by setting `row`, or a new vector can be added. - - key (unicode / int): The key to add. - vector (numpy.ndarray / None): A vector to add for the key. - row (int / None): The row-number of a vector to map the key to. - """ - if isinstance(key, basestring_): - key = self.strings.add(key) - if row is None and key in self.key2row: - row = self.key2row[key] - elif row is None: - row = self._i_vec - self._i_vec += 1 - if row >= self.data.shape[0]: - self.data.resize((row*2, self.data.shape[1])) - if key not in self.key2row: - if self._i_key >= self.keys.shape[0]: - self.keys.resize((self._i_key*2,)) - self.keys[self._i_key] = key - self._i_key += 1 - - self.key2row[key] = row - if vector is not None: - self.data[row] = vector - return row - - def items(self): - """Iterate over `(string key, vector)` pairs, in order. - - YIELDS (tuple): A key/vector pair. - """ - for i, key in enumerate(self.keys): - string = self.strings[key] - row = self.key2row[key] - yield string, self.data[row] - + self._unset = set() + if keys is not None: + for i, key in enumerate(keys): + self.add(key, row=i) + @property def shape(self): """Get `(rows, dims)` tuples of number of rows and number of dimensions @@ -166,9 +67,184 @@ cdef class Vectors: """ return self.data.shape - def most_similar(self, key): - # TODO: implement - raise NotImplementedError + @property + def size(self): + """Return rows*dims""" + return self.data.shape[0] * self.data.shape[1] + + @property + def is_full(self): + """Returns True if no keys are available for new keys.""" + return len(self._unset) == 0 + + @property + def n_keys(self): + """Returns True if no keys are available for new keys.""" + return len(self.key2row) + + def __reduce__(self): + keys_and_rows = self.key2row.items() + return (unpickle_vectors, (keys_and_rows, self.data)) + + def __getitem__(self, key): + """Get a vector by key. If the key is not found, a KeyError is raised. + + key (int): The key to get the vector for. + RETURNS (ndarray): The vector for the key. + """ + i = self.key2row[key] + if i is None: + raise KeyError(key) + else: + return self.data[i] + + def __setitem__(self, key, vector): + """Set a vector for the given key. + + key (int): The key to set the vector for. + vector (numpy.ndarray): The vector to set. + """ + i = self.key2row[key] + self.data[i] = vector + if i in self._unset: + self._unset.remove(i) + + def __iter__(self): + """Yield vectors from the table. + + YIELDS (ndarray): A vector. + """ + yield from self.key2row + + def __len__(self): + """Return the number of vectors in the table. + + RETURNS (int): The number of vectors in the data. + """ + return self.data.shape[0] + + def __contains__(self, key): + """Check whether a key has been mapped to a vector entry in the table. + + key (int): The key to check. + RETURNS (bool): Whether the key has a vector entry. + """ + return key in self.key2row + + def resize(self, shape, inplace=False): + '''Resize the underlying vectors array. If inplace=True, the memory + is reallocated. 
@@ -166,9 +68,184 @@ cdef class Vectors:
         """
         return self.data.shape
 
-    def most_similar(self, key):
-        # TODO: implement
-        raise NotImplementedError
+    @property
+    def size(self):
+        """Return rows*dims"""
+        return self.data.shape[0] * self.data.shape[1]
+
+    @property
+    def is_full(self):
+        """Returns True if no unassigned rows are available for new keys."""
+        return len(self._unset) == 0
+
+    @property
+    def n_keys(self):
+        """Returns the number of keys that have been mapped to a row."""
+        return len(self.key2row)
+
+    def __reduce__(self):
+        keys_and_rows = tuple(self.key2row.items())
+        return (unpickle_vectors, (keys_and_rows, self.data))
+
+    def __getitem__(self, key):
+        """Get a vector by key. If the key is not found, a KeyError is raised.
+
+        key (int): The key to get the vector for.
+        RETURNS (ndarray): The vector for the key.
+        """
+        i = self.key2row[key]
+        if i is None:
+            raise KeyError(key)
+        else:
+            return self.data[i]
+
+    def __setitem__(self, key, vector):
+        """Set a vector for the given key.
+
+        key (int): The key to set the vector for.
+        vector (numpy.ndarray): The vector to set.
+        """
+        i = self.key2row[key]
+        self.data[i] = vector
+        if i in self._unset:
+            self._unset.remove(i)
+
+    def __iter__(self):
+        """Yield keys from the table.
+
+        YIELDS (int): A key.
+        """
+        yield from self.key2row
+
+    def __len__(self):
+        """Return the number of vectors in the table.
+
+        RETURNS (int): The number of vectors in the data.
+        """
+        return self.data.shape[0]
+
+    def __contains__(self, key):
+        """Check whether a key has been mapped to a vector entry in the table.
+
+        key (int): The key to check.
+        RETURNS (bool): Whether the key has a vector entry.
+        """
+        return key in self.key2row
+
+    def resize(self, shape, inplace=False):
+        '''Resize the underlying vectors array. If inplace=True, the memory
+        is reallocated. This may cause other references to the data to become
+        invalid, so only use inplace=True if you're sure that's what you want.
+
+        If the number of vectors is reduced, keys mapped to rows that have been
+        deleted are removed. These removed items are returned as a list of
+        (key, row) tuples.
+        '''
+        if inplace:
+            self.data.resize(shape, refcheck=False)
+        else:
+            xp = get_array_module(self.data)
+            self.data = xp.resize(self.data, shape)
+        filled = {row for row in self.key2row.values()}
+        self._unset = {row for row in range(shape[0]) if row not in filled}
+        removed_items = []
+        for key, row in list(self.key2row.items()):
+            if row >= shape[0]:
+                self.key2row.pop(key)
+                removed_items.append((key, row))
+        return removed_items
+
+    def keys(self):
+        '''Iterate over the keys in the table.'''
+        yield from self.key2row.keys()
+
+    def values(self):
+        '''Iterate over vectors that have been assigned to at least one key.
+
+        Note that some vectors may be unassigned, so the number of vectors
+        returned may be less than the length of the vectors table.'''
+        for row in range(self.data.shape[0]):
+            if row not in self._unset:
+                yield self.data[row]
+
+    def items(self):
+        """Iterate over `(key, vector)` pairs.
+
+        YIELDS (tuple): A key/vector pair.
+        """
+        for key, row in self.key2row.items():
+            yield key, self.data[row]
+
+    def get_keys(self, rows):
+        xp = get_array_module(self.data)
+        row2key = {row: key for key, row in self.key2row.items()}
+        keys = xp.asarray([row2key[row] for row in rows],
+                          dtype='uint64')
+        return keys
+
+    def get_rows(self, keys):
+        xp = get_array_module(self.data)
+        k2r = self.key2row
+        return xp.asarray([k2r.get(key, -1) for key in keys], dtype='i')
+
+    def add(self, key, *, vector=None, row=None):
+        """Add a key to the table. Keys can be mapped to an existing vector
+        by setting `row`, or a new vector can be added.
+
+        key (unicode / int): The key to add.
+        vector (numpy.ndarray / None): A vector to add for the key.
+        row (int / None): The row-number of a vector to map the key to.
+        """
+        if row is None and key in self.key2row:
+            row = self.key2row[key]
+        elif row is None:
+            if self.is_full:
+                raise ValueError("Cannot add new key to vectors -- full")
+            row = min(self._unset)
+
+        self.key2row[key] = row
+        if vector is not None:
+            self.data[row] = vector
+        if row in self._unset:
+            self._unset.remove(row)
+        return row
+
+    def most_similar(self, queries, *, return_scores=False, return_rows=False,
+                     batch_size=1024):
+        '''For each of the given vectors, find the single entry most similar
+        to it, by cosine.
+
+        Queries are by vector. Results are returned as an array of keys,
+        or a tuple of (keys, scores) if return_scores=True. If `queries` is
+        large, the calculations are performed in chunks, to avoid consuming
+        too much memory. You can set the `batch_size` to control the
+        speed/memory trade-off during the calculations.
+        '''
+        xp = get_array_module(self.data)
+
+        vectors = self.data / xp.linalg.norm(self.data, axis=1, keepdims=True)
+
+        best_rows = xp.zeros((queries.shape[0],), dtype='i')
+        scores = xp.zeros((queries.shape[0],), dtype='f')
+        # Work in batches, to avoid memory problems.
+        for i in range(0, queries.shape[0], batch_size):
+            batch = queries[i : i+batch_size]
+            batch = batch / xp.linalg.norm(batch, axis=1, keepdims=True)
+            # batch   e.g. (1024, 300)
+            # vectors e.g. (10000, 300)
+            # sims    e.g. (1024, 10000)
+            sims = xp.dot(batch, vectors.T)
+            best_rows[i:i+batch_size] = sims.argmax(axis=1)
+            scores[i:i+batch_size] = sims.max(axis=1)
+        keys = self.get_keys(best_rows)
+        if return_rows and return_scores:
+            return (keys, best_rows, scores)
+        elif return_rows:
+            return (keys, best_rows)
+        elif return_scores:
+            return (keys, scores)
+        else:
+            return keys
 
     def from_glove(self, path):
         """Load GloVe vectors from a directory. Assumes binary format,
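`most_similar` is what `Vocab.prune_vectors` (below) builds on. A small
self-contained sketch:

    import numpy
    from spacy.vectors import Vectors

    v = Vectors(data=numpy.asarray([[1., 0.], [0., 1.]], dtype='f'),
                keys=[10, 20])
    queries = numpy.asarray([[0.9, 0.1]], dtype='f')
    keys, rows, scores = v.most_similar(queries, return_rows=True,
                                        return_scores=True)
    # keys[0] == 10: row 0 has the highest cosine similarity to the query.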
@@ -178,27 +255,34 @@ cdef class Vectors:
         By default GloVe outputs 64-bit vectors.
 
         path (unicode / Path): The path to load the GloVe vectors from.
+
+        RETURNS: A StringStore object, holding the key-to-string mapping.
         """
         path = util.ensure_path(path)
+        width = None
         for name in path.iterdir():
             if name.parts[-1].startswith('vectors'):
                 _, dims, dtype, _2 = name.parts[-1].split('.')
-                self.width = int(dims)
+                width = int(dims)
                 break
         else:
             raise IOError("Expected file named e.g. vectors.128.f.bin")
         bin_loc = path / 'vectors.{dims}.{dtype}.bin'.format(dims=dims,
                                                              dtype=dtype)
+        xp = get_array_module(self.data)
+        self.data = None
         with bin_loc.open('rb') as file_:
-            self.data = numpy.fromfile(file_, dtype='float64')
-            self.data = numpy.ascontiguousarray(self.data, dtype='float32')
+            self.data = xp.fromfile(file_, dtype=dtype)
+            if dtype != 'float32':
+                self.data = xp.ascontiguousarray(self.data, dtype='float32')
+        self.data = self.data.reshape((self.data.size // width, width))
         n = 0
+        strings = StringStore()
         with (path / 'vocab.txt').open('r') as file_:
-            for line in file_:
-                self.add(line.strip())
-                n += 1
-        if (self.data.size % self.width) == 0:
-            self.data
+            for i, line in enumerate(file_):
+                key = strings.add(line.strip())
+                self.add(key, row=i)
+        return strings
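A hedged usage sketch for the GloVe loader, assuming the directory layout the
docstring describes (a vectors.<dims>.<dtype>.bin file plus a vocab.txt with
one word per line; the path is illustrative):

    from spacy.vectors import Vectors

    v = Vectors(shape=(0, 0))
    strings = v.from_glove('/path/to/glove')
    # Line i of vocab.txt names the key stored in row i; `strings`
    # holds the hash-to-string mapping for those keys.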
""" + def load_key2row(path): + if path.exists(): + self.key2row = msgpack.load(path.open('rb')) + for key, row in self.key2row.items(): + if row in self._unset: + self._unset.remove(row) + def load_keys(path): if path.exists(): - self.keys = numpy.load(path2str(path)) - for i, key in enumerate(self.keys): - self.keys[i] = key - self.key2row[key] = i + keys = numpy.load(str(path)) + for i, key in enumerate(keys): + self.add(key, row=i) def load_vectors(path): xp = Model.ops.xp @@ -238,6 +326,7 @@ cdef class Vectors: self.data = xp.load(path) serializers = OrderedDict(( + ('key2row', load_key2row), ('keys', load_keys), ('vectors', load_vectors), )) @@ -256,7 +345,7 @@ cdef class Vectors: else: return msgpack.dumps(self.data) serializers = OrderedDict(( - ('keys', lambda: msgpack.dumps(self.keys)), + ('key2row', lambda: msgpack.dumps(self.key2row)), ('vectors', serialize_weights) )) return util.to_bytes(serializers, exclude) @@ -274,14 +363,8 @@ cdef class Vectors: else: self.data = msgpack.loads(b) - def load_keys(keys): - self.keys.resize((len(keys),)) - for i, key in enumerate(keys): - self.keys[i] = key - self.key2row[key] = i - deserializers = OrderedDict(( - ('keys', lambda b: load_keys(msgpack.loads(b))), + ('key2row', lambda b: self.key2row.update(msgpack.loads(b))), ('vectors', deserialize_weights) )) util.from_bytes(data, deserializers, exclude) diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx index 0e6b69ebd..14b62a808 100644 --- a/spacy/vocab.pyx +++ b/spacy/vocab.pyx @@ -55,7 +55,7 @@ cdef class Vocab: _ = self[string] self.lex_attr_getters = lex_attr_getters self.morphology = Morphology(self.strings, tag_map, lemmatizer) - self.vectors = Vectors(self.strings, width=0) + self.vectors = Vectors() property lang: def __get__(self): @@ -192,10 +192,11 @@ cdef class Vocab: YIELDS (Lexeme): An entry in the vocabulary. """ - cdef attr_t orth + cdef attr_t key cdef size_t addr - for orth, addr in self._by_orth.items(): - yield Lexeme(self, orth) + for key, addr in self._by_orth.items(): + lex = Lexeme(self, key) + yield lex def __getitem__(self, id_or_string): """Retrieve a lexeme, given an int ID or a unicode string. If a @@ -213,7 +214,7 @@ cdef class Vocab: >>> assert nlp.vocab[apple] == nlp.vocab[u'apple'] """ cdef attr_t orth - if type(id_or_string) == unicode: + if isinstance(id_or_string, unicode): orth = self.strings.add(id_or_string) else: orth = id_or_string @@ -240,15 +241,19 @@ cdef class Vocab: def vectors_length(self): return self.vectors.data.shape[1] - def clear_vectors(self, width=None): + def reset_vectors(self, *, width=None, shape=None): """Drop the current vector table. Because all vectors must be the same width, you have to call this to change the size of the vectors. """ - if width is None: - width = self.vectors.data.shape[1] - self.vectors = Vectors(self.strings, width=width) + if width is not None and shape is not None: + raise ValueError("Only one of width and shape can be specified") + elif shape is not None: + self.vectors = Vectors(shape=shape) + else: + width = width if width is not None else self.vectors.data.shape[1] + self.vectors = Vectors(shape=(self.vectors.shape[0], width)) - def prune_vectors(self, nr_row, batch_size=8): + def prune_vectors(self, nr_row, batch_size=1024): """Reduce the current vector table to `nr_row` unique entries. Words mapped to the discarded vectors will be remapped to the closest vector among those remaining. @@ -274,36 +279,31 @@ cdef class Vocab: two words. 
""" xp = get_array_module(self.vectors.data) - # Work in batches, to avoid memory problems. - keep = self.vectors.data[:nr_row] - keep_keys = [key for key, row in self.vectors.key2row.items() if row < nr_row] - toss = self.vectors.data[nr_row:] - # Normalize the vectors, so cosine similarity is just dot product. - # Note we can't modify the ones we're keeping in-place... - keep = keep / (xp.linalg.norm(keep, axis=1, keepdims=True)+1e-8) - keep = xp.ascontiguousarray(keep.T) - neighbours = xp.zeros((toss.shape[0],), dtype='i') - scores = xp.zeros((toss.shape[0],), dtype='f') - for i in range(0, toss.shape[0], batch_size): - batch = toss[i : i+batch_size] - batch /= xp.linalg.norm(batch, axis=1, keepdims=True)+1e-8 - sims = xp.dot(batch, keep) - matches = sims.argmax(axis=1) - neighbours[i:i+batch_size] = matches - scores[i:i+batch_size] = sims.max(axis=1) - for lex in self: - # If we're losing the vector for this word, map it to the nearest - # vector we're keeping. - if lex.rank >= nr_row: - lex.rank = neighbours[lex.rank-nr_row] - self.vectors.add(lex.orth, row=lex.rank) - for key in self.vectors.keys: - row = self.vectors.key2row[key] - if row >= nr_row: - self.vectors.key2row[key] = neighbours[row-nr_row] - # Make copy, to encourage the original table to be garbage collected. - self.vectors.data = xp.ascontiguousarray(self.vectors.data[:nr_row]) - # TODO: return new mapping + # Make prob negative so it sorts by rank ascending + # (key2row contains the rank) + priority = [(-lex.prob, self.vectors.key2row[lex.orth], lex.orth) + for lex in self if lex.orth in self.vectors.key2row] + priority.sort() + indices = xp.asarray([i for (prob, i, key) in priority], dtype='i') + keys = xp.asarray([key for (prob, i, key) in priority], dtype='uint64') + + keep = xp.ascontiguousarray(self.vectors.data[indices[:nr_row]]) + toss = xp.ascontiguousarray(self.vectors.data[indices[nr_row:]]) + + self.vectors = Vectors(data=keep, keys=keys) + + syn_keys, syn_rows, scores = self.vectors.most_similar(toss, + return_rows=True, return_scores=True) + + remap = {} + for i, key in enumerate(keys[nr_row:]): + self.vectors.add(key, row=syn_rows[i]) + word = self.strings[key] + synonym = self.strings[syn_keys[i]] + score = scores[i] + remap[word] = (synonym, score) + link_vectors_to_models(self) + return remap def get_vector(self, orth): """Retrieve a vector for a word in the vocabulary. Words can be looked @@ -325,8 +325,16 @@ cdef class Vocab: """Set a vector for a word in the vocabulary. Words can be referenced by string or int ID. """ - if not isinstance(orth, basestring_): - orth = self.strings[orth] + if isinstance(orth, basestring_): + orth = self.strings.add(orth) + if self.vectors.is_full and orth not in self.vectors: + new_rows = max(100, int(self.vectors.shape[0]*1.3)) + if self.vectors.shape[1] == 0: + width = vector.size + else: + width = self.vectors.shape[1] + self.vectors.resize((new_rows, width)) + self.vectors.add(orth, vector=vector) self.vectors.add(orth, vector=vector) def has_vector(self, orth):