# coding: utf-8
from __future__ import unicode_literals

import numpy
from numpy.testing import assert_allclose

from ...vocab import Vocab
from ..._ml import cosine


def test_vocab_add_vector():
    """Vectors stored with ``set_vector`` are readable via the lexeme."""
    vocab = Vocab()
    # Zero-initialised so rows that are never assigned hold defined values
    # instead of whatever happened to be in memory.
    data = numpy.zeros((5, 3), dtype='f')
    data[0] = 1.
    data[1] = 2.
    vocab.set_vector(u'cat', data[0])
    vocab.set_vector(u'dog', data[1])

    cat = vocab[u'cat']
    assert list(cat.vector) == [1., 1., 1.]
    dog = vocab[u'dog']
    assert list(dog.vector) == [2., 2., 2.]


def test_vocab_prune_vectors():
    """Pruning to two vectors remaps ``kitten`` onto ``cat``.

    The test expects ``prune_vectors(2)`` to return a mapping with a single
    entry keyed by ``kitten`` whose value is a ``(neighbour, similarity)``
    pair, where the neighbour is ``cat`` and the similarity matches the
    cosine between the ``cat`` and ``kitten`` vectors.
    """
    vocab = Vocab()
    # Look the words up first so they exist in the vocab before vectors are
    # attached; the returned lexemes themselves are not needed.
    _ = vocab[u'cat']
    _ = vocab[u'dog']
    _ = vocab[u'kitten']

    data = numpy.zeros((5, 3), dtype='f')
    # The vectors must NOT be collinear: constant-filled rows such as
    # [1, 1, 1], [2, 2, 2] and [1.1, 1.1, 1.1] all have cosine similarity 1.0
    # with each other, which makes the "nearest neighbour" a tie between
    # 'cat' and 'dog'. These values put 'kitten' clearly closest to 'cat'
    # (cos ~= 0.999) and further from 'dog' (cos ~= 0.945).
    data[0] = [1.0, 1.2, 1.1]
    data[1] = [0.3, 1.3, 1.0]
    data[2] = [0.9, 1.22, 1.05]
    vocab.set_vector(u'cat', data[0])
    vocab.set_vector(u'dog', data[1])
    vocab.set_vector(u'kitten', data[2])

    remap = vocab.prune_vectors(2)
    assert list(remap.keys()) == [u'kitten']
    neighbour, similarity = list(remap.values())[0]
    assert neighbour == u'cat', remap
    # float32 arithmetic: the two similarity computations may take different
    # code paths, so allow a small relative and absolute tolerance.
    assert_allclose(similarity, cosine(data[0], data[2]),
                    atol=1e-4, rtol=1e-3)