Update vectors.add() to allow setting keys to rows

This commit is contained in:
Explosion Bot 2017-10-30 10:03:08 +01:00
parent 256c7dac5a
commit 72aea8f105
3 changed files with 32 additions and 20 deletions

View File

@ -209,7 +209,7 @@ def test_doc_api_right_edge(en_tokenizer):
def test_doc_api_has_vector(): def test_doc_api_has_vector():
vocab = Vocab() vocab = Vocab()
vocab.clear_vectors(2) vocab.clear_vectors(2)
vocab.vectors.add('kitten', numpy.asarray([0., 2.], dtype='f')) vocab.vectors.add('kitten', vector=numpy.asarray([0., 2.], dtype='f'))
doc = Doc(vocab, words=['kitten']) doc = Doc(vocab, words=['kitten'])
assert doc.has_vector assert doc.has_vector

View File

@ -73,8 +73,8 @@ def test_doc_token_api_is_properties(en_vocab):
def test_doc_token_api_vectors(): def test_doc_token_api_vectors():
vocab = Vocab() vocab = Vocab()
vocab.clear_vectors(2) vocab.clear_vectors(2)
vocab.vectors.add('apples', numpy.asarray([0., 2.], dtype='f')) vocab.vectors.add('apples', vector=numpy.asarray([0., 2.], dtype='f'))
vocab.vectors.add('oranges', numpy.asarray([0., 1.], dtype='f')) vocab.vectors.add('oranges', vector=numpy.asarray([0., 1.], dtype='f'))
doc = Doc(vocab, words=['apples', 'oranges', 'oov']) doc = Doc(vocab, words=['apples', 'oranges', 'oov'])
assert doc.has_vector assert doc.has_vector

View File

@ -21,8 +21,10 @@ cdef class Vectors:
Vectors data is kept in the vectors.data attribute, which should be an Vectors data is kept in the vectors.data attribute, which should be an
instance of numpy.ndarray (for CPU vectors) or cupy.ndarray instance of numpy.ndarray (for CPU vectors) or cupy.ndarray
(for GPU vectors). `vectors.key2row` is a dictionary mapping word hashes to (for GPU vectors). `vectors.key2row` is a dictionary mapping word hashes to
rows in the vectors.data table. The array `vectors.keys` keeps the keys in rows in the vectors.data table.
order, such that `keys[vectors.key2row[key]] == key`.
Multiple keys can be mapped to the same vector, so len(keys) may be greater
(but not smaller) than data.shape[0].
""" """
cdef public object data cdef public object data
cdef readonly StringStore strings cdef readonly StringStore strings
@ -57,7 +59,7 @@ cdef class Vectors:
for i, string in enumerate(self.strings): for i, string in enumerate(self.strings):
if i >= self.data.shape[0]: if i >= self.data.shape[0]:
break break
self.add(self.strings[string], self.data[i]) self.add(self.strings[string], vector=self.data[i])
def __reduce__(self): def __reduce__(self):
return (Vectors, (self.strings, self.data)) return (Vectors, (self.strings, self.data))
@ -114,27 +116,36 @@ cdef class Vectors:
key = self.strings[key] key = self.strings[key]
return key in self.key2row return key in self.key2row
def add(self, key, vector=None): def add(self, key, *, vector=None, row=None):
"""Add a key to the table, optionally setting a vector value as well. """Add a key to the table. Keys can be mapped to an existing vector
by setting `row`, or a new vector can be added.
key (unicode / int): The key to add. key (unicode / int): The key to add.
vector (numpy.ndarray): An optional vector to add. vector (numpy.ndarray / None): A vector to add for the key.
row (int / None): The row-number of a vector to map the key to.
""" """
if row is not None and vector is not None:
raise ValueError("Only one of 'row' and 'vector' may be set")
if isinstance(key, basestring_): if isinstance(key, basestring_):
key = self.strings.add(key) key = self.strings.add(key)
if key not in self.key2row: if key in self.key2row and vector is not None:
i = self.i row = self.key2row[key]
if i >= self.keys.shape[0]: elif key in self.key2row and row is not None:
self.keys.resize((self.keys.shape[0]*2,)) self.key2row[key] = row
self.data.resize((self.data.shape[0]*2, self.data.shape[1])) elif key not in self.key2row:
self.key2row[key] = self.i if row is not None:
self.key2row[key] = row
else:
self.key2row[key] = self.i
row = self.i
if row >= self.keys.shape[0]:
self.keys.resize((row*2,))
self.data.resize((row*2, self.data.shape[1]))
self.keys[self.i] = key self.keys[self.i] = key
self.i += 1 self.i += 1
else:
i = self.key2row[key]
if vector is not None: if vector is not None:
self.data[i] = vector self.data[row] = vector
return i return row
def items(self): def items(self):
"""Iterate over `(string key, vector)` pairs, in order. """Iterate over `(string key, vector)` pairs, in order.
@ -143,7 +154,8 @@ cdef class Vectors:
""" """
for i, key in enumerate(self.keys): for i, key in enumerate(self.keys):
string = self.strings[key] string = self.strings[key]
yield string, self.data[i] row = self.key2row[key]
yield string, self.data[row]
@property @property
def shape(self): def shape(self):