Add docstrings for spacy.vectors

This commit is contained in:
Matthew Honnibal 2017-10-01 22:10:33 +02:00
parent 158e177cae
commit 97c409b602

View File

@ -16,7 +16,16 @@ from .compat import basestring_
cdef class Vectors: cdef class Vectors:
'''Store, save and load word vectors.''' '''Store, save and load word vectors.
Vectors data is kept in the vectors.data attribute, which should be an
instance of numpy.ndarray (for CPU vectors)
or cupy.ndarray (for GPU vectors).
vectors.key2row is a dictionary mapping word hashes to rows
in the vectors.data table. The array `vectors.keys` keeps
the keys in order, such that keys[vectors.key2row[key]] == key.
'''
cdef public object data cdef public object data
cdef readonly StringStore strings cdef readonly StringStore strings
cdef public object key2row cdef public object key2row
@ -24,7 +33,36 @@ cdef class Vectors:
cdef public int i cdef public int i
def __init__(self, strings, data_or_width=0): def __init__(self, strings, data_or_width=0):
self.strings = StringStore() '''Create a new vector store.
To keep the vector table empty, pass data_or_width=0:
>>> empty_vectors = Vectors(StringStore())
To create the vector table, and add vectors one-by-one:
>>> my_vector_data = {
... 'dog': numpy.random.uniform(-1, 1, (300,)),
... 'cat': numpy.random.uniform(-1, 1, (300,)),
... 'orange': numpy.random.uniform(-1, 1, (300,)),
... }
>>> strings = StringStore()
>>> for word in my_vector_data.keys():
... strings.add(word)
>>> vectors = Vectors(strings, 300)
>>> for word in strings:
... vectors[word] = my_vector_data[word]
To set the vector values directly on initialization:
>>> my_vector_table = numpy.zeros((3, 300), dtype='f')
>>> strings = StringStore()
>>> for key in my_vector_data.keys():
... strings.add(key)
>>> for i, word in enumerate(strings):
... my_vector_table[i] = my_vector_data[word]
>>> vectors = Vectors(strings, my_vector_table)
'''
if isinstance(data_or_width, int): if isinstance(data_or_width, int):
self.data = data = numpy.zeros((len(strings), data_or_width), self.data = data = numpy.zeros((len(strings), data_or_width),
dtype='f') dtype='f')
@ -39,6 +77,11 @@ cdef class Vectors:
return (Vectors, (self.strings, self.data)) return (Vectors, (self.strings, self.data))
def __getitem__(self, key): def __getitem__(self, key):
'''Get a vector by key. If key is a string, it is hashed
to an integer ID using the vectors.strings table.
If the integer key is not found in the table, a KeyError is raised.
'''
if isinstance(key, basestring): if isinstance(key, basestring):
key = self.strings[key] key = self.strings[key]
i = self.key2row[key] i = self.key2row[key]
@ -48,23 +91,30 @@ cdef class Vectors:
return self.data[i] return self.data[i]
def __setitem__(self, key, vector): def __setitem__(self, key, vector):
'''Set a vector for the given key. If key is a string, it is hashed
to an integer ID using the vectors.strings table.
'''
if isinstance(key, basestring): if isinstance(key, basestring):
key = self.strings.add(key) key = self.strings.add(key)
i = self.key2row[key] i = self.key2row[key]
self.data[i] = vector self.data[i] = vector
def __iter__(self): def __iter__(self):
'''Yield vectors from the table.'''
yield from self.data yield from self.data
def __len__(self): def __len__(self):
'''Return the number of vectors that have been assigned.'''
return self.i return self.i
def __contains__(self, key): def __contains__(self, key):
'''Check whether a key has a vector entry in the table.'''
if isinstance(key, basestring_): if isinstance(key, basestring_):
key = self.strings[key] key = self.strings[key]
return key in self.key2row return key in self.key2row
def add(self, key, vector=None): def add(self, key, vector=None):
'''Add a key to the table, optionally setting a vector value as well.'''
if isinstance(key, basestring_): if isinstance(key, basestring_):
key = self.strings.add(key) key = self.strings.add(key)
if key not in self.key2row: if key not in self.key2row:
@ -82,6 +132,7 @@ cdef class Vectors:
return i return i
def items(self): def items(self):
'''Iterate over (string key, vector) pairs, in order.'''
for i, string in enumerate(self.strings): for i, string in enumerate(self.strings):
yield string, self.data[i] yield string, self.data[i]