mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
Fix vector remapping
This commit is contained in:
parent
9c11ee4a1c
commit
cb5217012f
10
spacy/_ml.py
10
spacy/_ml.py
|
@ -30,9 +30,13 @@ VECTORS_KEY = 'spacy_pretrained_vectors'
|
|||
|
||||
|
||||
def cosine(vec1, vec2):
    """Return the cosine similarity between two vectors.

    vec1, vec2: Array objects supporting `**`, `.sum()` and `.dot()`
        (the visible test passes numpy rows).
    RETURNS: `vec1.dot(vec2)` divided by the product of the two Euclidean
        norms, or 0.0 if either vector has zero norm.
    """
    # Compute the norms with array methods only, so the function does not
    # depend on looking up the array module of its inputs.
    norm1 = (vec1 ** 2).sum() ** 0.5
    norm2 = (vec2 ** 2).sum() ** 0.5
    # An all-zero vector has no direction: report zero similarity instead
    # of dividing by zero and returning NaN.
    if norm1 == 0. or norm2 == 0.:
        return 0.0
    return vec1.dot(vec2) / (norm1 * norm2)
|
||||
|
||||
|
||||
@layerize
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
from numpy.testing import assert_allclose
|
||||
from ...vocab import Vocab
|
||||
from ..._ml import cosine
|
||||
|
||||
|
@ -18,8 +18,6 @@ def test_vocab_add_vector():
|
|||
assert list(cat.vector) == [1., 1., 1.]
|
||||
dog = vocab[u'dog']
|
||||
assert list(dog.vector) == [2., 2., 2.]
|
||||
for lex in vocab:
|
||||
print(lex.orth_)
|
||||
|
||||
|
||||
def test_vocab_prune_vectors():
|
||||
|
@ -27,7 +25,6 @@ def test_vocab_prune_vectors():
|
|||
_ = vocab[u'cat']
|
||||
_ = vocab[u'dog']
|
||||
_ = vocab[u'kitten']
|
||||
print(list(vocab.strings))
|
||||
data = numpy.ndarray((5,3), dtype='f')
|
||||
data[0] = 1.
|
||||
data[1] = 2.
|
||||
|
@ -35,9 +32,9 @@ def test_vocab_prune_vectors():
|
|||
vocab.set_vector(u'cat', data[0])
|
||||
vocab.set_vector(u'dog', data[1])
|
||||
vocab.set_vector(u'kitten', data[2])
|
||||
for lex in vocab:
|
||||
print(lex.orth_)
|
||||
|
||||
remap = vocab.prune_vectors(2)
|
||||
assert remap == {u'kitten': (u'cat', cosine(data[0], data[2]))}
|
||||
#print(remap)
|
||||
assert list(remap.keys()) == [u'kitten']
|
||||
neighbour, similarity = remap.values()[0]
|
||||
assert neighbour == u'cat'
|
||||
assert_allclose(similarity, cosine(data[0], data[2]), atol=1e-6)
|
||||
|
|
|
@ -281,24 +281,26 @@ cdef class Vocab:
|
|||
toss = self.vectors.data[nr_row:]
|
||||
# Normalize the vectors, so cosine similarity is just dot product.
|
||||
# Note we can't modify the ones we're keeping in-place...
|
||||
keep = keep / (xp.linalg.norm(keep, axis=1, keepdims=True)+1e-8)
|
||||
keep = keep / (xp.linalg.norm(keep, axis=1, keepdims=True)+1e-12)
|
||||
keep = xp.ascontiguousarray(keep.T)
|
||||
neighbours = xp.zeros((toss.shape[0],), dtype='i')
|
||||
scores = xp.zeros((toss.shape[0],), dtype='f')
|
||||
for i in range(0, toss.shape[0], batch_size):
|
||||
batch = toss[i : i+batch_size]
|
||||
batch /= xp.linalg.norm(batch, axis=1, keepdims=True)+1e-8
|
||||
batch /= xp.linalg.norm(batch, axis=1, keepdims=True)+1e-12
|
||||
sims = xp.dot(batch, keep)
|
||||
matches = sims.argmax(axis=1)
|
||||
neighbours[i:i+batch_size] = matches
|
||||
scores[i:i+batch_size] = sims.max(axis=1)
|
||||
for lex in self:
|
||||
i2k = {i: key for key, i in self.vectors.key2row.items()}
|
||||
remap = {}
|
||||
for lex in list(self):
|
||||
# If we're losing the vector for this word, map it to the nearest
|
||||
# vector we're keeping.
|
||||
if lex.rank >= nr_row:
|
||||
lex.rank = neighbours[lex.rank-nr_row]
|
||||
self.vectors.add(lex.orth, row=lex.rank)
|
||||
remap[lex.orth_] = (i2k[lex.rank], scores[lex.rank])
|
||||
remap[lex.orth_] = (self.strings[i2k[lex.rank]], scores[lex.rank])
|
||||
for key, row in self.vectors.key2row.items():
|
||||
if row >= nr_row:
|
||||
self.vectors.key2row[key] = neighbours[row-nr_row]
|
||||
|
|
Loading…
Reference in New Issue
Block a user