Improve vector handling

This commit is contained in:
Matthew Honnibal 2017-08-19 20:35:33 +02:00
parent ef87562741
commit 1157294434
2 changed files with 26 additions and 16 deletions

View File

@@ -18,6 +18,7 @@ cdef class Vectors:
cdef readonly StringStore strings
cdef public object key2row
cdef public object keys
cdef public int i
def __init__(self, strings, data_or_width):
self.strings = StringStore()
@@ -26,13 +27,12 @@ cdef class Vectors:
dtype='f')
else:
data = data_or_width
self.i = 0
self.data = data
self.key2row = {}
self.keys = np.ndarray((self.data.shape[0],), dtype='uint64')
for i, string in enumerate(strings):
key = self.strings.add(string)
self.key2row[key] = i
self.keys[i] = key
for string in strings:
self.add_key(string)
def __reduce__(self):
return (Vectors, (self.strings, self.data))
@@ -56,21 +56,29 @@ cdef class Vectors:
yield from self.data
def __len__(self):
# TODO: Fix the quadratic behaviour here!
return max(self.key2row.values())
return self.i
def __contains__(self, key):
if isinstance(key, basestring_):
key = self.strings[key]
return key in self.key2row
def add_key(self, string, vector=None):
key = self.strings.add(string)
next_i = len(self) + 1
self.keys[next_i] = key
self.key2row[key] = next_i
def add(self, key, vector=None):
if isinstance(key, basestring_):
key = self.strings.add(key)
if key not in self.key2row:
i = self.i
if i >= self.keys.shape[0]:
self.keys.resize((self.keys.shape[0]*2,))
self.data.resize((self.data.shape[0]*2, self.data.shape[1]))
self.key2row[key] = self.i
self.keys[self.i] = key
self.i += 1
else:
i = self.key2row[key]
if vector is not None:
self.data[next_i] = vector
self.data[i] = vector
return i
def items(self):
for i, string in enumerate(self.strings):
@@ -139,5 +147,5 @@ cdef class Vectors:
('strings', lambda b: self.strings.from_bytes(b)),
('vectors', deserialize_weights)
))
util.from_bytes(deserializers, exclude)
util.from_bytes(data, deserializers, exclude)
return self

View File

@@ -246,11 +246,13 @@ cdef class Vocab:
def vectors_length(self):
return len(self.vectors)
def clear_vectors(self):
def clear_vectors(self, new_dim=None):
"""Drop the current vector table. Because all vectors must be the same
width, you have to call this to change the size of the vectors.
"""
raise NotImplementedError
if new_dim is None:
new_dim = self.vectors.data.shape[1]
self.vectors = Vectors(self.strings, new_dim)
def get_vector(self, orth):
"""Retrieve a vector for a word in the vocabulary.
@@ -278,7 +280,7 @@ cdef class Vocab:
"""
if not isinstance(orth, basestring_):
orth = self.strings[orth]
self.vectors.add_key(orth, vector=vector)
self.vectors.add(orth, vector=vector)
def has_vector(self, orth):
"""Check whether a word has a vector. Returns False if no