Add prune_vectors method to Vocab

Matthew Honnibal 2017-10-30 17:59:43 +01:00
parent d0cf12c8c7
commit e026b29ea9

@@ -5,6 +5,7 @@ import numpy
 import dill
 from collections import OrderedDict
+from thinc.neural.util import get_array_module
 from .lexeme cimport EMPTY_LEXEME
 from .lexeme cimport Lexeme
 from .strings cimport hash_string
@@ -247,6 +248,44 @@ cdef class Vocab:
         width = self.vectors.data.shape[1]
         self.vectors = Vectors(self.strings, width=width)
+
+    def prune_vectors(self, nr_row, batch_size=1024):
+        """Reduce the current vector table to `nr_row` unique entries. Words
+        mapped to the discarded vectors will be remapped to the closest vector
+        among those remaining.
+
+        For example, suppose the original table had vectors for the words
+        ['sat', 'cat', 'feline', 'reclined']. If we prune the vector table to
+        two rows, we would discard the vectors for 'feline' and 'reclined'.
+        These words would then be remapped to the closest remaining vector --
+        so 'feline' would have the same vector as 'cat', and 'reclined' would
+        have the same vector as 'sat'.
+
+        The similarities are judged by cosine. The original vectors may be
+        large, so the cosines are calculated in minibatches to reduce memory
+        usage.
+        """
+        xp = get_array_module(self.vectors.data)
+        # Work in batches, to avoid memory problems.
+        keep = self.vectors.data[:nr_row]
+        toss = self.vectors.data[nr_row:]
+        # Normalize each row, so cosine similarity is just the dot product.
+        # Note we can't modify the ones we're keeping in-place...
+        keep = keep / (xp.linalg.norm(keep, axis=1, keepdims=True) + 1e-8)
+        keep = xp.ascontiguousarray(keep.T)
+        neighbours = xp.zeros((toss.shape[0],), dtype='i')
+        for i in range(0, toss.shape[0], batch_size):
+            batch = toss[i : i+batch_size]
+            batch /= xp.linalg.norm(batch, axis=1, keepdims=True) + 1e-8
+            neighbours[i : i+batch_size] = xp.dot(batch, keep).argmax(axis=1)
+        for lex in self:
+            # If we're losing the vector for this word, map it to the nearest
+            # vector we're keeping.
+            if lex.rank >= nr_row:
+                lex.rank = neighbours[lex.rank - nr_row]
+                self.vectors.add(lex.orth, row=lex.rank)
+        # Make a copy, to encourage the original table to be garbage collected.
+        self.vectors.data = xp.ascontiguousarray(self.vectors.data[:nr_row])

     def get_vector(self, orth):
         """Retrieve a vector for a word in the vocabulary. Words can be looked
         up by string or int ID. If no vectors data is loaded, ValueError is
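
A minimal usage sketch for the new method (not part of the diff above): it assumes a spaCy model with word vectors is installed, e.g. `en_core_web_md` -- the model name and the row counts are illustrative assumptions, not taken from this commit.

    import spacy

    # Load a model that ships with word vectors (assumed installed).
    nlp = spacy.load('en_core_web_md')
    print(nlp.vocab.vectors.data.shape)       # e.g. (20000, 300)

    # Keep only the first 10,000 rows of the vector table; every word whose
    # vector is discarded is remapped to its nearest remaining neighbour.
    nlp.vocab.prune_vectors(10000, batch_size=1024)
    print(nlp.vocab.vectors.data.shape)       # (10000, 300)

    # Lookups still return a vector for pruned words -- they now share a row
    # with the closest kept entry, as described in the docstring.
    print(nlp.vocab[u'feline'].vector.shape)  # (300,)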