Update docstrings and simplify most_similar

This commit is contained in:
ines 2017-11-01 00:18:08 +01:00
parent ba2e6c8c6f
commit 2ad2f09d12

View File

@ -70,17 +70,18 @@ cdef class Vectors:
@property @property
def size(self): def size(self):
"""Return rows*dims""" """RETURNS (int): rows*dims"""
return self.data.shape[0] * self.data.shape[1] return self.data.shape[0] * self.data.shape[1]
@property @property
def is_full(self): def is_full(self):
"""Returns True if no keys are available for new keys.""" """RETURNS (bool): `True` if no slots are available for new keys."""
return len(self._unset) == 0 return len(self._unset) == 0
@property @property
def n_keys(self): def n_keys(self):
"""Returns True if no keys are available for new keys.""" """RETURNS (int): The number of keys in the table. Note that this is the
number of all keys, not just unique vectors."""
return len(self.key2row) return len(self.key2row)
def __reduce__(self): def __reduce__(self):
@ -198,9 +199,10 @@ cdef class Vectors:
"""Add a key to the table. Keys can be mapped to an existing vector """Add a key to the table. Keys can be mapped to an existing vector
by setting `row`, or a new vector can be added. by setting `row`, or a new vector can be added.
key (unicode / int): The key to add. key (int): The key to add.
vector (numpy.ndarray / None): A vector to add for the key. vector (ndarray / None): A vector to add for the key.
row (int / None): The row-number of a vector to map the key to. row (int / None): The row number of a vector to map the key to.
RETURNS (int): The row the vector was added to.
""" """
if row is None and key in self.key2row: if row is None and key in self.key2row:
row = self.key2row[key] row = self.key2row[key]
@ -216,17 +218,20 @@ cdef class Vectors:
self._unset.remove(row) self._unset.remove(row)
return row return row
def most_similar(self, queries, *, return_scores=False, return_rows=False, def most_similar(self, queries, *, batch_size=1024):
batch_size=1024): """For each of the given vectors, find the single entry most similar
'''For each of the given vectors, find the single entry most similar
to it, by cosine. to it, by cosine.
Queries are by vector. Results are returned as an array of keys, Queries are by vector. Results are returned as a `(keys, best_rows,
or a tuple of (keys, scores) if return_scores=True. If `queries` is scores)` tuple. If `queries` is large, the calculations are performed in
large, the calculations are performed in chunks, to avoid consuming chunks, to avoid consuming too much memory. You can set the `batch_size`
too much memory. You can set the `batch_size` to control the size/space to control the size/space trade-off during the calculations.
trade-off during the calculations.
''' queries (ndarray): An array with one or more vectors.
batch_size (int): The batch size to use.
RETURNS (tuple): The most similar entry as a `(keys, best_rows, scores)`
tuple.
"""
xp = get_array_module(self.data) xp = get_array_module(self.data)
vectors = self.data / xp.linalg.norm(self.data, axis=1, keepdims=True) vectors = self.data / xp.linalg.norm(self.data, axis=1, keepdims=True)
@ -244,14 +249,7 @@ cdef class Vectors:
best_rows[i:i+batch_size] = sims.argmax(axis=1) best_rows[i:i+batch_size] = sims.argmax(axis=1)
scores[i:i+batch_size] = sims.max(axis=1) scores[i:i+batch_size] = sims.max(axis=1)
keys = self.get_keys(best_rows) keys = self.get_keys(best_rows)
if return_rows and return_scores: return (keys, best_rows, scores)
return (keys, best_rows, scores)
elif return_rows:
return (keys, best_rows)
elif return_scores:
return (keys, scores)
else:
return keys
def from_glove(self, path): def from_glove(self, path):
"""Load GloVe vectors from a directory. Assumes binary format, """Load GloVe vectors from a directory. Assumes binary format,
@ -261,8 +259,7 @@ cdef class Vectors:
By default GloVe outputs 64-bit vectors. By default GloVe outputs 64-bit vectors.
path (unicode / Path): The path to load the GloVe vectors from. path (unicode / Path): The path to load the GloVe vectors from.
RETURNS: A `StringStore` object, holding the key-to-string mapping.
RETURNS: A StringStore object, holding the key-to-string mapping.
""" """
path = util.ensure_path(path) path = util.ensure_path(path)
width = None width = None