Update vectors.add() to allow setting keys to rows

This commit is contained in:
Explosion Bot 2017-10-30 10:03:08 +01:00
parent 256c7dac5a
commit 72aea8f105
3 changed files with 32 additions and 20 deletions

View File

@ -209,7 +209,7 @@ def test_doc_api_right_edge(en_tokenizer):
def test_doc_api_has_vector():
vocab = Vocab()
vocab.clear_vectors(2)
vocab.vectors.add('kitten', numpy.asarray([0., 2.], dtype='f'))
vocab.vectors.add('kitten', vector=numpy.asarray([0., 2.], dtype='f'))
doc = Doc(vocab, words=['kitten'])
assert doc.has_vector

View File

@ -73,8 +73,8 @@ def test_doc_token_api_is_properties(en_vocab):
def test_doc_token_api_vectors():
vocab = Vocab()
vocab.clear_vectors(2)
vocab.vectors.add('apples', numpy.asarray([0., 2.], dtype='f'))
vocab.vectors.add('oranges', numpy.asarray([0., 1.], dtype='f'))
vocab.vectors.add('apples', vector=numpy.asarray([0., 2.], dtype='f'))
vocab.vectors.add('oranges', vector=numpy.asarray([0., 1.], dtype='f'))
doc = Doc(vocab, words=['apples', 'oranges', 'oov'])
assert doc.has_vector

View File

@ -21,8 +21,10 @@ cdef class Vectors:
Vectors data is kept in the vectors.data attribute, which should be an
instance of numpy.ndarray (for CPU vectors) or cupy.ndarray
(for GPU vectors). `vectors.key2row` is a dictionary mapping word hashes to
rows in the vectors.data table. The array `vectors.keys` keeps the keys in
order, such that `keys[vectors.key2row[key]] == key`.
rows in the vectors.data table.
Multiple keys can be mapped to the same vector, so len(keys) may be greater
(but not smaller) than data.shape[0].
"""
cdef public object data
cdef readonly StringStore strings
@ -57,7 +59,7 @@ cdef class Vectors:
for i, string in enumerate(self.strings):
if i >= self.data.shape[0]:
break
self.add(self.strings[string], self.data[i])
self.add(self.strings[string], vector=self.data[i])
def __reduce__(self):
return (Vectors, (self.strings, self.data))
@ -114,27 +116,36 @@ cdef class Vectors:
key = self.strings[key]
return key in self.key2row
def add(self, key, vector=None):
"""Add a key to the table, optionally setting a vector value as well.
def add(self, key, *, vector=None, row=None):
"""Add a key to the table. Keys can be mapped to an existing vector
by setting `row`, or a new vector can be added.
key (unicode / int): The key to add.
vector (numpy.ndarray): An optional vector to add.
vector (numpy.ndarray / None): A vector to add for the key.
row (int / None): The row-number of a vector to map the key to.
"""
if row is not None and vector is not None:
raise ValueError("Only one of 'row' and 'vector' may be set")
if isinstance(key, basestring_):
key = self.strings.add(key)
if key not in self.key2row:
i = self.i
if i >= self.keys.shape[0]:
self.keys.resize((self.keys.shape[0]*2,))
self.data.resize((self.data.shape[0]*2, self.data.shape[1]))
if key in self.key2row and vector is not None:
row = self.key2row[key]
elif key in self.key2row and row is not None:
self.key2row[key] = row
elif key not in self.key2row:
if row is not None:
self.key2row[key] = row
else:
self.key2row[key] = self.i
row = self.i
if row >= self.keys.shape[0]:
self.keys.resize((row*2,))
self.data.resize((row*2, self.data.shape[1]))
self.keys[self.i] = key
self.i += 1
else:
i = self.key2row[key]
if vector is not None:
self.data[i] = vector
return i
self.data[row] = vector
return row
def items(self):
"""Iterate over `(string key, vector)` pairs, in order.
@ -143,7 +154,8 @@ cdef class Vectors:
"""
for i, key in enumerate(self.keys):
string = self.strings[key]
yield string, self.data[i]
row = self.key2row[key]
yield string, self.data[row]
@property
def shape(self):