mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Update vectors.add() to allow setting keys to rows
This commit is contained in:
parent
256c7dac5a
commit
72aea8f105
|
@ -209,7 +209,7 @@ def test_doc_api_right_edge(en_tokenizer):
|
|||
def test_doc_api_has_vector():
|
||||
vocab = Vocab()
|
||||
vocab.clear_vectors(2)
|
||||
vocab.vectors.add('kitten', numpy.asarray([0., 2.], dtype='f'))
|
||||
vocab.vectors.add('kitten', vector=numpy.asarray([0., 2.], dtype='f'))
|
||||
doc = Doc(vocab, words=['kitten'])
|
||||
assert doc.has_vector
|
||||
|
||||
|
|
|
@ -73,8 +73,8 @@ def test_doc_token_api_is_properties(en_vocab):
|
|||
def test_doc_token_api_vectors():
|
||||
vocab = Vocab()
|
||||
vocab.clear_vectors(2)
|
||||
vocab.vectors.add('apples', numpy.asarray([0., 2.], dtype='f'))
|
||||
vocab.vectors.add('oranges', numpy.asarray([0., 1.], dtype='f'))
|
||||
vocab.vectors.add('apples', vector=numpy.asarray([0., 2.], dtype='f'))
|
||||
vocab.vectors.add('oranges', vector=numpy.asarray([0., 1.], dtype='f'))
|
||||
doc = Doc(vocab, words=['apples', 'oranges', 'oov'])
|
||||
assert doc.has_vector
|
||||
|
||||
|
|
|
@ -21,8 +21,10 @@ cdef class Vectors:
|
|||
Vectors data is kept in the vectors.data attribute, which should be an
|
||||
instance of numpy.ndarray (for CPU vectors) or cupy.ndarray
|
||||
(for GPU vectors). `vectors.key2row` is a dictionary mapping word hashes to
|
||||
rows in the vectors.data table. The array `vectors.keys` keeps the keys in
|
||||
order, such that `keys[vectors.key2row[key]] == key`.
|
||||
rows in the vectors.data table.
|
||||
|
||||
Multiple keys can be mapped to the same vector, so len(keys) may be greater
|
||||
(but not smaller) than data.shape[0].
|
||||
"""
|
||||
cdef public object data
|
||||
cdef readonly StringStore strings
|
||||
|
@ -57,7 +59,7 @@ cdef class Vectors:
|
|||
for i, string in enumerate(self.strings):
|
||||
if i >= self.data.shape[0]:
|
||||
break
|
||||
self.add(self.strings[string], self.data[i])
|
||||
self.add(self.strings[string], vector=self.data[i])
|
||||
|
||||
def __reduce__(self):
|
||||
return (Vectors, (self.strings, self.data))
|
||||
|
@ -114,27 +116,36 @@ cdef class Vectors:
|
|||
key = self.strings[key]
|
||||
return key in self.key2row
|
||||
|
||||
def add(self, key, vector=None):
|
||||
"""Add a key to the table, optionally setting a vector value as well.
|
||||
def add(self, key, *, vector=None, row=None):
|
||||
"""Add a key to the table. Keys can be mapped to an existing vector
|
||||
by setting `row`, or a new vector can be added.
|
||||
|
||||
key (unicode / int): The key to add.
|
||||
vector (numpy.ndarray): An optional vector to add.
|
||||
vector (numpy.ndarray / None): A vector to add for the key.
|
||||
row (int / None): The row-number of a vector to map the key to.
|
||||
"""
|
||||
if row is not None and vector is not None:
|
||||
raise ValueError("Only one of 'row' and 'vector' may be set")
|
||||
if isinstance(key, basestring_):
|
||||
key = self.strings.add(key)
|
||||
if key not in self.key2row:
|
||||
i = self.i
|
||||
if i >= self.keys.shape[0]:
|
||||
self.keys.resize((self.keys.shape[0]*2,))
|
||||
self.data.resize((self.data.shape[0]*2, self.data.shape[1]))
|
||||
self.key2row[key] = self.i
|
||||
if key in self.key2row and vector is not None:
|
||||
row = self.key2row[key]
|
||||
elif key in self.key2row and row is not None:
|
||||
self.key2row[key] = row
|
||||
elif key not in self.key2row:
|
||||
if row is not None:
|
||||
self.key2row[key] = row
|
||||
else:
|
||||
self.key2row[key] = self.i
|
||||
row = self.i
|
||||
if row >= self.keys.shape[0]:
|
||||
self.keys.resize((row*2,))
|
||||
self.data.resize((row*2, self.data.shape[1]))
|
||||
self.keys[self.i] = key
|
||||
self.i += 1
|
||||
else:
|
||||
i = self.key2row[key]
|
||||
if vector is not None:
|
||||
self.data[i] = vector
|
||||
return i
|
||||
self.data[row] = vector
|
||||
return row
|
||||
|
||||
def items(self):
|
||||
"""Iterate over `(string key, vector)` pairs, in order.
|
||||
|
@ -143,7 +154,8 @@ cdef class Vectors:
|
|||
"""
|
||||
for i, key in enumerate(self.keys):
|
||||
string = self.strings[key]
|
||||
yield string, self.data[i]
|
||||
row = self.key2row[key]
|
||||
yield string, self.data[row]
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
|
|
Loading…
Reference in New Issue
Block a user