mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Fix vector pruning
This commit is contained in:
parent
e98451b5f7
commit
4112a991ec
|
@ -30,7 +30,8 @@ cdef class Vectors:
|
||||||
cdef readonly StringStore strings
|
cdef readonly StringStore strings
|
||||||
cdef public object key2row
|
cdef public object key2row
|
||||||
cdef public object keys
|
cdef public object keys
|
||||||
cdef public int i
|
cdef public int _i_key
|
||||||
|
cdef public int _i_vec
|
||||||
|
|
||||||
def __init__(self, strings, width=0, data=None):
|
def __init__(self, strings, width=0, data=None):
|
||||||
"""Create a new vector store. To keep the vector table empty, pass
|
"""Create a new vector store. To keep the vector table empty, pass
|
||||||
|
@ -53,7 +54,8 @@ cdef class Vectors:
|
||||||
self.data = numpy.asarray(data, dtype='f')
|
self.data = numpy.asarray(data, dtype='f')
|
||||||
else:
|
else:
|
||||||
self.data = numpy.zeros((len(self.strings), width), dtype='f')
|
self.data = numpy.zeros((len(self.strings), width), dtype='f')
|
||||||
self.i = 0
|
self._i_key = 0
|
||||||
|
self._i_vec = 0
|
||||||
self.key2row = {}
|
self.key2row = {}
|
||||||
self.keys = numpy.zeros((self.data.shape[0],), dtype='uint64')
|
self.keys = numpy.zeros((self.data.shape[0],), dtype='uint64')
|
||||||
if data is not None:
|
if data is not None:
|
||||||
|
@ -105,7 +107,7 @@ cdef class Vectors:
|
||||||
|
|
||||||
RETURNS (int): The number of vectors in the data.
|
RETURNS (int): The number of vectors in the data.
|
||||||
"""
|
"""
|
||||||
return self.i
|
return self._i_vec
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
"""Check whether a key has a vector entry in the table.
|
"""Check whether a key has a vector entry in the table.
|
||||||
|
@ -127,20 +129,20 @@ cdef class Vectors:
|
||||||
"""
|
"""
|
||||||
if isinstance(key, basestring_):
|
if isinstance(key, basestring_):
|
||||||
key = self.strings.add(key)
|
key = self.strings.add(key)
|
||||||
if key in self.key2row and row is None:
|
if row is None and key in self.key2row:
|
||||||
row = self.key2row[key]
|
row = self.key2row[key]
|
||||||
elif key in self.key2row and row is not None:
|
|
||||||
self.key2row[key] = row
|
|
||||||
elif row is None:
|
elif row is None:
|
||||||
row = self.i
|
row = self._i_vec
|
||||||
self.i += 1
|
self._i_vec += 1
|
||||||
if row >= self.keys.shape[0]:
|
if row >= self.data.shape[0]:
|
||||||
self.keys.resize((row*2,))
|
|
||||||
self.data.resize((row*2, self.data.shape[1]))
|
self.data.resize((row*2, self.data.shape[1]))
|
||||||
self.keys[row] = key
|
if key not in self.key2row:
|
||||||
|
if self._i_key >= self.keys.shape[0]:
|
||||||
|
self.keys.resize((self._i_key*2,))
|
||||||
|
self.keys[self._i_key] = key
|
||||||
|
self._i_key += 1
|
||||||
|
|
||||||
self.key2row[key] = row
|
self.key2row[key] = row
|
||||||
self.keys[row] = key
|
|
||||||
if vector is not None:
|
if vector is not None:
|
||||||
self.data[row] = vector
|
self.data[row] = vector
|
||||||
return row
|
return row
|
||||||
|
|
|
@ -248,7 +248,7 @@ cdef class Vocab:
|
||||||
width = self.vectors.data.shape[1]
|
width = self.vectors.data.shape[1]
|
||||||
self.vectors = Vectors(self.strings, width=width)
|
self.vectors = Vectors(self.strings, width=width)
|
||||||
|
|
||||||
def prune_vectors(self, nr_row, batch_size=1024):
|
def prune_vectors(self, nr_row, batch_size=8):
|
||||||
"""Reduce the current vector table to `nr_row` unique entries. Words
|
"""Reduce the current vector table to `nr_row` unique entries. Words
|
||||||
mapped to the discarded vectors will be remapped to the closest vector
|
mapped to the discarded vectors will be remapped to the closest vector
|
||||||
among those remaining.
|
among those remaining.
|
||||||
|
@ -267,22 +267,31 @@ cdef class Vocab:
|
||||||
xp = get_array_module(self.vectors.data)
|
xp = get_array_module(self.vectors.data)
|
||||||
# Work in batches, to avoid memory problems.
|
# Work in batches, to avoid memory problems.
|
||||||
keep = self.vectors.data[:nr_row]
|
keep = self.vectors.data[:nr_row]
|
||||||
|
keep_keys = [key for key, row in self.vectors.key2row.items() if row < nr_row]
|
||||||
toss = self.vectors.data[nr_row:]
|
toss = self.vectors.data[nr_row:]
|
||||||
# Normalize the vectors, so cosine similarity is just dot product.
|
# Normalize the vectors, so cosine similarity is just dot product.
|
||||||
# Note we can't modify the ones we're keeping in-place...
|
# Note we can't modify the ones we're keeping in-place...
|
||||||
keep = keep / (xp.linalg.norm(keep)+1e-8)
|
keep = keep / (xp.linalg.norm(keep, axis=1, keepdims=True)+1e-8)
|
||||||
keep = xp.ascontiguousarray(keep.T)
|
keep = xp.ascontiguousarray(keep.T)
|
||||||
neighbours = xp.zeros((toss.shape[0],), dtype='i')
|
neighbours = xp.zeros((toss.shape[0],), dtype='i')
|
||||||
|
scores = xp.zeros((toss.shape[0],), dtype='f')
|
||||||
for i in range(0, toss.shape[0], batch_size):
|
for i in range(0, toss.shape[0], batch_size):
|
||||||
batch = toss[i : i+batch_size]
|
batch = toss[i : i+batch_size]
|
||||||
batch /= xp.linalg.norm(batch)+1e-8
|
batch /= xp.linalg.norm(batch, axis=1, keepdims=True)+1e-8
|
||||||
neighbours[i:i+batch_size] = xp.dot(batch, keep).argmax(axis=1)
|
sims = xp.dot(batch, keep)
|
||||||
|
matches = sims.argmax(axis=1)
|
||||||
|
neighbours[i:i+batch_size] = matches
|
||||||
|
scores[i:i+batch_size] = sims.max(axis=1)
|
||||||
for lex in self:
|
for lex in self:
|
||||||
# If we're losing the vector for this word, map it to the nearest
|
# If we're losing the vector for this word, map it to the nearest
|
||||||
# vector we're keeping.
|
# vector we're keeping.
|
||||||
if lex.rank >= nr_row:
|
if lex.rank >= nr_row:
|
||||||
lex.rank = neighbours[lex.rank-nr_row]
|
lex.rank = neighbours[lex.rank-nr_row]
|
||||||
self.vectors.add(lex.orth, row=lex.rank)
|
self.vectors.add(lex.orth, row=lex.rank)
|
||||||
|
for key in self.vectors.keys:
|
||||||
|
row = self.vectors.key2row[key]
|
||||||
|
if row >= nr_row:
|
||||||
|
self.vectors.key2row[key] = neighbours[row-nr_row]
|
||||||
# Make copy, to encourage the original table to be garbage collected.
|
# Make copy, to encourage the original table to be garbage collected.
|
||||||
self.vectors.data = xp.ascontiguousarray(self.vectors.data[:nr_row])
|
self.vectors.data = xp.ascontiguousarray(self.vectors.data[:nr_row])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user