Fix (and test) vector pickling

This commit is contained in:
Matthew Honnibal 2017-12-07 09:53:30 +01:00
parent 2ae4755def
commit 36b47e3fa6
2 changed files with 5 additions and 0 deletions

View File

@@ -3,6 +3,7 @@ from __future__ import unicode_literals
import pytest import pytest
import dill as pickle import dill as pickle
import numpy
from ..vocab import Vocab from ..vocab import Vocab
from ..attrs import NORM from ..attrs import NORM
@@ -22,6 +23,7 @@ def test_pickle_string_store(stringstore, text1, text2):
@pytest.mark.parametrize('text1,text2', [('dog', 'cat')]) @pytest.mark.parametrize('text1,text2', [('dog', 'cat')])
def test_pickle_vocab(text1, text2): def test_pickle_vocab(text1, text2):
vocab = Vocab(lex_attr_getters={int(NORM): lambda string: string[:-1]}) vocab = Vocab(lex_attr_getters={int(NORM): lambda string: string[:-1]})
vocab.set_vector('dog', numpy.ones((5,), dtype='f'))
lex1 = vocab[text1] lex1 = vocab[text1]
lex2 = vocab[text2] lex2 = vocab[text2]
assert lex1.norm_ == text1[:-1] assert lex1.norm_ == text1[:-1]
@@ -33,3 +35,5 @@ def test_pickle_vocab(text1, text2):
assert unpickled[text1].norm == lex1.norm assert unpickled[text1].norm == lex1.norm
assert unpickled[text2].norm == lex2.norm assert unpickled[text2].norm == lex2.norm
assert unpickled[text1].norm != unpickled[text2].norm assert unpickled[text1].norm != unpickled[text2].norm
assert unpickled.vectors is not None
assert list(vocab['dog'].vector) == [1.,1.,1.,1.,1.]

View File

@@ -19,6 +19,7 @@ def unpickle_vectors(keys_and_rows, data):
vectors = Vectors(data=data) vectors = Vectors(data=data)
for key, row in keys_and_rows: for key, row in keys_and_rows:
vectors.add(key, row=row) vectors.add(key, row=row)
return vectors
cdef class Vectors: cdef class Vectors: