From 72aea8f1057d251c88306c11146c2a9c0ca0c3c2 Mon Sep 17 00:00:00 2001 From: Explosion Bot Date: Mon, 30 Oct 2017 10:03:08 +0100 Subject: [PATCH] Update vectors.add() to allow setting keys to rows --- spacy/tests/doc/test_doc_api.py | 2 +- spacy/tests/doc/test_token_api.py | 4 +-- spacy/vectors.pyx | 46 +++++++++++++++++++------------ 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py index 46c615973..8f881e811 100644 --- a/spacy/tests/doc/test_doc_api.py +++ b/spacy/tests/doc/test_doc_api.py @@ -209,7 +209,7 @@ def test_doc_api_right_edge(en_tokenizer): def test_doc_api_has_vector(): vocab = Vocab() vocab.clear_vectors(2) - vocab.vectors.add('kitten', numpy.asarray([0., 2.], dtype='f')) + vocab.vectors.add('kitten', vector=numpy.asarray([0., 2.], dtype='f')) doc = Doc(vocab, words=['kitten']) assert doc.has_vector diff --git a/spacy/tests/doc/test_token_api.py b/spacy/tests/doc/test_token_api.py index 0ab723f7a..a52be9731 100644 --- a/spacy/tests/doc/test_token_api.py +++ b/spacy/tests/doc/test_token_api.py @@ -73,8 +73,8 @@ def test_doc_token_api_is_properties(en_vocab): def test_doc_token_api_vectors(): vocab = Vocab() vocab.clear_vectors(2) - vocab.vectors.add('apples', numpy.asarray([0., 2.], dtype='f')) - vocab.vectors.add('oranges', numpy.asarray([0., 1.], dtype='f')) + vocab.vectors.add('apples', vector=numpy.asarray([0., 2.], dtype='f')) + vocab.vectors.add('oranges', vector=numpy.asarray([0., 1.], dtype='f')) doc = Doc(vocab, words=['apples', 'oranges', 'oov']) assert doc.has_vector diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx index 155d7b9d2..d6b59401e 100644 --- a/spacy/vectors.pyx +++ b/spacy/vectors.pyx @@ -21,8 +21,10 @@ cdef class Vectors: Vectors data is kept in the vectors.data attribute, which should be an instance of numpy.ndarray (for CPU vectors) or cupy.ndarray (for GPU vectors). `vectors.key2row` is a dictionary mapping word hashes to - rows in the vectors.data table. The array `vectors.keys` keeps the keys in - order, such that `keys[vectors.key2row[key]] == key`. + rows in the vectors.data table. + + Multiple keys can be mapped to the same vector, so len(keys) may be greater + (but not smaller) than data.shape[0]. """ cdef public object data cdef readonly StringStore strings @@ -57,7 +59,7 @@ cdef class Vectors: for i, string in enumerate(self.strings): if i >= self.data.shape[0]: break - self.add(self.strings[string], self.data[i]) + self.add(self.strings[string], vector=self.data[i]) def __reduce__(self): return (Vectors, (self.strings, self.data)) @@ -114,27 +116,36 @@ cdef class Vectors: key = self.strings[key] return key in self.key2row - def add(self, key, vector=None): - """Add a key to the table, optionally setting a vector value as well. + def add(self, key, *, vector=None, row=None): + """Add a key to the table. Keys can be mapped to an existing vector + by setting `row`, or a new vector can be added. key (unicode / int): The key to add. - vector (numpy.ndarray): An optional vector to add. + vector (numpy.ndarray / None): A vector to add for the key. + row (int / None): The row-number of a vector to map the key to. """ + if row is not None and vector is not None: + raise ValueError("Only one of 'row' and 'vector' may be set") if isinstance(key, basestring_): key = self.strings.add(key) - if key not in self.key2row: - i = self.i - if i >= self.keys.shape[0]: - self.keys.resize((self.keys.shape[0]*2,)) - self.data.resize((self.data.shape[0]*2, self.data.shape[1])) - self.key2row[key] = self.i + if key in self.key2row and vector is not None: + row = self.key2row[key] + elif key in self.key2row and row is not None: + self.key2row[key] = row + elif key not in self.key2row: + if row is not None: + self.key2row[key] = row + else: + self.key2row[key] = self.i + row = self.i + if row >= self.keys.shape[0]: + self.keys.resize((row*2,)) + self.data.resize((row*2, self.data.shape[1])) self.keys[self.i] = key self.i += 1 - else: - i = self.key2row[key] if vector is not None: - self.data[i] = vector - return i + self.data[row] = vector + return row def items(self): """Iterate over `(string key, vector)` pairs, in order. @@ -143,7 +154,8 @@ cdef class Vectors: """ for i, key in enumerate(self.keys): string = self.strings[key] - yield string, self.data[i] + row = self.key2row[key] + yield string, self.data[row] @property def shape(self):