mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
Update vectors.add() to allow setting keys to rows
This commit is contained in:
parent
256c7dac5a
commit
72aea8f105
|
@ -209,7 +209,7 @@ def test_doc_api_right_edge(en_tokenizer):
|
||||||
def test_doc_api_has_vector():
|
def test_doc_api_has_vector():
|
||||||
vocab = Vocab()
|
vocab = Vocab()
|
||||||
vocab.clear_vectors(2)
|
vocab.clear_vectors(2)
|
||||||
vocab.vectors.add('kitten', numpy.asarray([0., 2.], dtype='f'))
|
vocab.vectors.add('kitten', vector=numpy.asarray([0., 2.], dtype='f'))
|
||||||
doc = Doc(vocab, words=['kitten'])
|
doc = Doc(vocab, words=['kitten'])
|
||||||
assert doc.has_vector
|
assert doc.has_vector
|
||||||
|
|
||||||
|
|
|
@ -73,8 +73,8 @@ def test_doc_token_api_is_properties(en_vocab):
|
||||||
def test_doc_token_api_vectors():
|
def test_doc_token_api_vectors():
|
||||||
vocab = Vocab()
|
vocab = Vocab()
|
||||||
vocab.clear_vectors(2)
|
vocab.clear_vectors(2)
|
||||||
vocab.vectors.add('apples', numpy.asarray([0., 2.], dtype='f'))
|
vocab.vectors.add('apples', vector=numpy.asarray([0., 2.], dtype='f'))
|
||||||
vocab.vectors.add('oranges', numpy.asarray([0., 1.], dtype='f'))
|
vocab.vectors.add('oranges', vector=numpy.asarray([0., 1.], dtype='f'))
|
||||||
doc = Doc(vocab, words=['apples', 'oranges', 'oov'])
|
doc = Doc(vocab, words=['apples', 'oranges', 'oov'])
|
||||||
assert doc.has_vector
|
assert doc.has_vector
|
||||||
|
|
||||||
|
|
|
@ -21,8 +21,10 @@ cdef class Vectors:
|
||||||
Vectors data is kept in the vectors.data attribute, which should be an
|
Vectors data is kept in the vectors.data attribute, which should be an
|
||||||
instance of numpy.ndarray (for CPU vectors) or cupy.ndarray
|
instance of numpy.ndarray (for CPU vectors) or cupy.ndarray
|
||||||
(for GPU vectors). `vectors.key2row` is a dictionary mapping word hashes to
|
(for GPU vectors). `vectors.key2row` is a dictionary mapping word hashes to
|
||||||
rows in the vectors.data table. The array `vectors.keys` keeps the keys in
|
rows in the vectors.data table.
|
||||||
order, such that `keys[vectors.key2row[key]] == key`.
|
|
||||||
|
Multiple keys can be mapped to the same vector, so len(keys) may be greater
|
||||||
|
(but not smaller) than data.shape[0].
|
||||||
"""
|
"""
|
||||||
cdef public object data
|
cdef public object data
|
||||||
cdef readonly StringStore strings
|
cdef readonly StringStore strings
|
||||||
|
@ -57,7 +59,7 @@ cdef class Vectors:
|
||||||
for i, string in enumerate(self.strings):
|
for i, string in enumerate(self.strings):
|
||||||
if i >= self.data.shape[0]:
|
if i >= self.data.shape[0]:
|
||||||
break
|
break
|
||||||
self.add(self.strings[string], self.data[i])
|
self.add(self.strings[string], vector=self.data[i])
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
return (Vectors, (self.strings, self.data))
|
return (Vectors, (self.strings, self.data))
|
||||||
|
@ -114,27 +116,36 @@ cdef class Vectors:
|
||||||
key = self.strings[key]
|
key = self.strings[key]
|
||||||
return key in self.key2row
|
return key in self.key2row
|
||||||
|
|
||||||
def add(self, key, vector=None):
|
def add(self, key, *, vector=None, row=None):
|
||||||
"""Add a key to the table, optionally setting a vector value as well.
|
"""Add a key to the table. Keys can be mapped to an existing vector
|
||||||
|
by setting `row`, or a new vector can be added.
|
||||||
|
|
||||||
key (unicode / int): The key to add.
|
key (unicode / int): The key to add.
|
||||||
vector (numpy.ndarray): An optional vector to add.
|
vector (numpy.ndarray / None): A vector to add for the key.
|
||||||
|
row (int / None): The row-number of a vector to map the key to.
|
||||||
"""
|
"""
|
||||||
|
if row is not None and vector is not None:
|
||||||
|
raise ValueError("Only one of 'row' and 'vector' may be set")
|
||||||
if isinstance(key, basestring_):
|
if isinstance(key, basestring_):
|
||||||
key = self.strings.add(key)
|
key = self.strings.add(key)
|
||||||
if key not in self.key2row:
|
if key in self.key2row and vector is not None:
|
||||||
i = self.i
|
row = self.key2row[key]
|
||||||
if i >= self.keys.shape[0]:
|
elif key in self.key2row and row is not None:
|
||||||
self.keys.resize((self.keys.shape[0]*2,))
|
self.key2row[key] = row
|
||||||
self.data.resize((self.data.shape[0]*2, self.data.shape[1]))
|
elif key not in self.key2row:
|
||||||
|
if row is not None:
|
||||||
|
self.key2row[key] = row
|
||||||
|
else:
|
||||||
self.key2row[key] = self.i
|
self.key2row[key] = self.i
|
||||||
|
row = self.i
|
||||||
|
if row >= self.keys.shape[0]:
|
||||||
|
self.keys.resize((row*2,))
|
||||||
|
self.data.resize((row*2, self.data.shape[1]))
|
||||||
self.keys[self.i] = key
|
self.keys[self.i] = key
|
||||||
self.i += 1
|
self.i += 1
|
||||||
else:
|
|
||||||
i = self.key2row[key]
|
|
||||||
if vector is not None:
|
if vector is not None:
|
||||||
self.data[i] = vector
|
self.data[row] = vector
|
||||||
return i
|
return row
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
"""Iterate over `(string key, vector)` pairs, in order.
|
"""Iterate over `(string key, vector)` pairs, in order.
|
||||||
|
@ -143,7 +154,8 @@ cdef class Vectors:
|
||||||
"""
|
"""
|
||||||
for i, key in enumerate(self.keys):
|
for i, key in enumerate(self.keys):
|
||||||
string = self.strings[key]
|
string = self.strings[key]
|
||||||
yield string, self.data[i]
|
row = self.key2row[key]
|
||||||
|
yield string, self.data[row]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self):
|
def shape(self):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user