spaCy/spacy/tests/vocab/test_add_vectors.py

41 lines
1.1 KiB
Python
Raw Normal View History

2017-01-12 17:09:49 +03:00
# coding: utf-8
from __future__ import unicode_literals
2017-01-12 17:09:49 +03:00
import numpy
2017-10-31 13:40:46 +03:00
from numpy.testing import assert_allclose
2017-10-31 04:00:26 +03:00
from ...vocab import Vocab
from ..._ml import cosine
2017-10-31 04:00:26 +03:00
def test_vocab_add_vector():
    """Vectors attached via ``Vocab.set_vector`` are readable from the lexeme.

    A 3-dimensional float32 row is stored for 'cat' (all 1.0) and for 'dog'
    (all 2.0); each is then read back through ``vocab[word].vector``.
    """
    vocab = Vocab()
    expected = ((u'cat', 1.), (u'dog', 2.))
    # Only the first len(expected) rows are ever used; the rest stay zero.
    table = numpy.zeros((5, 3), dtype='f')
    for row, (_, value) in enumerate(expected):
        table[row] = value
    for row, (word, _) in enumerate(expected):
        vocab.set_vector(word, table[row])
    for word, value in expected:
        assert list(vocab[word].vector) == [value] * 3
def test_vocab_prune_vectors():
    """Pruning to two rows remaps the dropped word to its nearest kept word.

    cat, dog and kitten get vectors in that order. After
    ``vocab.prune_vectors(2)`` the test expects exactly one remapped key
    (kitten), whose nearest surviving neighbour is cat, with the reported
    similarity equal to the cosine between the original cat and kitten
    vectors.
    """
    vocab = Vocab()
    # Look the words up before attaching vectors -- presumably this creates
    # the lexemes in this order (which pruning may rank by); TODO confirm.
    _ = vocab[u'cat']
    _ = vocab[u'dog']
    _ = vocab[u'kitten']
    # The vectors must not be collinear. With cat=[1,1,1], dog=[2,2,2] and
    # kitten=[1.1,1.1,1.1], kitten has cosine 1.0 to BOTH cat and dog, so
    # the "neighbour is cat" check would hinge on tie-breaking / rounding.
    # With these values kitten's cosine is ~0.999 to cat and ~0.945 to dog.
    data = numpy.asarray([
        [1.0, 1.2, 1.1],    # cat
        [0.3, 1.3, 1.0],    # dog
        [0.9, 1.22, 1.05],  # kitten: close to cat, clearly farther from dog
    ], dtype='f')
    vocab.set_vector(u'cat', data[0])
    vocab.set_vector(u'dog', data[1])
    vocab.set_vector(u'kitten', data[2])
    remap = vocab.prune_vectors(2)
    assert list(remap.keys()) == [u'kitten']
    neighbour, similarity = list(remap.values())[0]
    assert neighbour == u'cat', remap
    # Inputs are float32 and the similarity is no longer exactly 1.0, so
    # allow a few float32 ulps of slack between the two computations.
    assert_allclose(similarity, cosine(data[0], data[2]), atol=1e-5)