Fix vector remapping

Matthew Honnibal 2017-10-31 11:40:46 +01:00
parent 9c11ee4a1c
commit cb5217012f
3 changed files with 18 additions and 15 deletions

View File

@@ -30,9 +30,13 @@ VECTORS_KEY = 'spacy_pretrained_vectors'
 def cosine(vec1, vec2):
-    norm1 = (vec1**2).sum() ** 0.5
-    norm2 = (vec2**2).sum() ** 0.5
-    return vec1.dot(vec2) / (norm1 * norm2)
+    xp = get_array_module(vec1)
+    norm1 = xp.linalg.norm(vec1)
+    norm2 = xp.linalg.norm(vec2)
+    if norm1 == 0. or norm2 == 0.:
+        return 0
+    else:
+        return vec1.dot(vec2) / (norm1 * norm2)
 @layerize
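
A note on the cosine change: the old version divided by the product of the norms unconditionally, so a zero vector produced a nan. The new version computes the norms through get_array_module (numpy or cupy, depending on where the array lives) and short-circuits to 0 when either norm is zero. A minimal numpy-only sketch of the guarded behaviour, with plain numpy standing in for get_array_module:

import numpy

def cosine(vec1, vec2):
    norm1 = numpy.linalg.norm(vec1)
    norm2 = numpy.linalg.norm(vec2)
    if norm1 == 0. or norm2 == 0.:
        return 0  # the old code divided 0 by 0 here and returned nan
    return vec1.dot(vec2) / (norm1 * norm2)

assert cosine(numpy.zeros(3, dtype='f'), numpy.ones(3, dtype='f')) == 0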

View File

@@ -2,7 +2,7 @@
 from __future__ import unicode_literals
 import numpy
 import pytest
+from numpy.testing import assert_allclose
 from ...vocab import Vocab
 from ..._ml import cosine
@@ -18,8 +18,6 @@ def test_vocab_add_vector():
     assert list(cat.vector) == [1., 1., 1.]
     dog = vocab[u'dog']
     assert list(dog.vector) == [2., 2., 2.]
-    for lex in vocab:
-        print(lex.orth_)
 def test_vocab_prune_vectors():
@@ -27,7 +25,6 @@ def test_vocab_prune_vectors():
     _ = vocab[u'cat']
     _ = vocab[u'dog']
     _ = vocab[u'kitten']
-    print(list(vocab.strings))
     data = numpy.ndarray((5,3), dtype='f')
     data[0] = 1.
     data[1] = 2.
@@ -35,9 +32,9 @@ def test_vocab_prune_vectors():
     vocab.set_vector(u'cat', data[0])
     vocab.set_vector(u'dog', data[1])
     vocab.set_vector(u'kitten', data[2])
-    for lex in vocab:
-        print(lex.orth_)
     remap = vocab.prune_vectors(2)
-    assert remap == {u'kitten': (u'cat', cosine(data[0], data[2]))}
-    #print(remap)
+    assert list(remap.keys()) == [u'kitten']
+    neighbour, similarity = list(remap.values())[0]
+    assert neighbour == u'cat'
+    assert_allclose(similarity, cosine(data[0], data[2]), atol=1e-6)
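
For reference, the rewritten assertions pin down what Vocab.prune_vectors returns: a dict keyed by the orth string of each pruned word, with (neighbour orth string, similarity) values. The similarity is a float that only needs to match cosine(data[0], data[2]) within a tolerance, hence assert_allclose instead of the old exact dict comparison. A rough sketch of consuming that mapping, mirroring the shape of the test rather than adding anything new (the vector values here are placeholders):

import numpy
from spacy.vocab import Vocab

vocab = Vocab()
_ = vocab[u'cat']
_ = vocab[u'dog']
_ = vocab[u'kitten']
data = numpy.zeros((3, 3), dtype='f')
data[0] = 1.
data[1] = 2.
data[2] = 3.
vocab.set_vector(u'cat', data[0])
vocab.set_vector(u'dog', data[1])
vocab.set_vector(u'kitten', data[2])
remap = vocab.prune_vectors(2)
for pruned, (neighbour, similarity) in remap.items():
    # e.g. u'kitten' mapped to (u'cat', <float similarity>) in the test above
    print(pruned, neighbour, similarity)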

View File

@@ -281,24 +281,26 @@ cdef class Vocab:
         toss = self.vectors.data[nr_row:]
         # Normalize the vectors, so cosine similarity is just dot product.
         # Note we can't modify the ones we're keeping in-place...
-        keep = keep / (xp.linalg.norm(keep, axis=1, keepdims=True)+1e-8)
+        keep = keep / (xp.linalg.norm(keep, axis=1, keepdims=True)+1e-12)
         keep = xp.ascontiguousarray(keep.T)
         neighbours = xp.zeros((toss.shape[0],), dtype='i')
         scores = xp.zeros((toss.shape[0],), dtype='f')
         for i in range(0, toss.shape[0], batch_size):
             batch = toss[i : i+batch_size]
-            batch /= xp.linalg.norm(batch, axis=1, keepdims=True)+1e-8
+            batch /= xp.linalg.norm(batch, axis=1, keepdims=True)+1e-12
             sims = xp.dot(batch, keep)
             matches = sims.argmax(axis=1)
             neighbours[i:i+batch_size] = matches
             scores[i:i+batch_size] = sims.max(axis=1)
-        for lex in self:
+        i2k = {i: key for key, i in self.vectors.key2row.items()}
+        remap = {}
+        for lex in list(self):
             # If we're losing the vector for this word, map it to the nearest
             # vector we're keeping.
             if lex.rank >= nr_row:
                 lex.rank = neighbours[lex.rank-nr_row]
                 self.vectors.add(lex.orth, row=lex.rank)
-                remap[lex.orth_] = (i2k[lex.rank], scores[lex.rank])
+                remap[lex.orth_] = (self.strings[i2k[lex.rank]], scores[lex.rank])
         for key, row in self.vectors.key2row.items():
             if row >= nr_row:
                 self.vectors.key2row[key] = neighbours[row-nr_row]
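
A note on the vocab.pyx change: the normalisation epsilon shrinks from 1e-8 to 1e-12 so it perturbs the norms less, the i2k/remap bookkeeping now sits directly above the loop, and the loop walks a list(self) snapshot of the lexemes. The neighbour key is also resolved through self.strings, so the values in remap carry the neighbour's orth string rather than its raw hash, which is what the updated test compares against u'cat'. A standalone numpy sketch of the nearest-neighbour search itself, as an illustration of the technique rather than the Cython implementation:

import numpy as xp

vectors = xp.random.rand(5, 3).astype('f')   # all rows, highest-ranked first
nr_row = 2                                   # number of rows to keep
keep, toss = vectors[:nr_row], vectors[nr_row:]

# Normalise so cosine similarity reduces to a dot product; the epsilon only
# guards against zero-length rows.
keep_n = keep / (xp.linalg.norm(keep, axis=1, keepdims=True) + 1e-12)
toss_n = toss / (xp.linalg.norm(toss, axis=1, keepdims=True) + 1e-12)

sims = toss_n.dot(keep_n.T)        # (n_toss, n_keep) cosine similarities
neighbours = sims.argmax(axis=1)   # kept row each discarded vector remaps to
scores = sims.max(axis=1)          # similarity to that nearest kept row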