Include noun chunks method when pickling Vocab

This commit is contained in:
Adriane Boyd 2021-02-12 13:27:46 +01:00
parent 26bf642afd
commit 5e47a54d29
2 changed files with 29 additions and 3 deletions

View File

@@ -1,7 +1,9 @@
import pytest
import numpy
import srsly
from spacy.lang.en import English
from spacy.strings import StringStore
from spacy.tokens import Doc
from spacy.vocab import Vocab
from spacy.attrs import NORM
@@ -20,7 +22,10 @@ def test_pickle_string_store(text1, text2):
@pytest.mark.parametrize("text1,text2", [("dog", "cat")])
def test_pickle_vocab(text1, text2):
    vocab = Vocab(
        lex_attr_getters={int(NORM): lambda string: string[:-1]},
        get_noun_chunks=English.Defaults.syntax_iterators.get("noun_chunks"),
    )
    vocab.set_vector("dog", numpy.ones((5,), dtype="f"))
    lex1 = vocab[text1]
    lex2 = vocab[text2]
@@ -34,4 +39,23 @@ def test_pickle_vocab(text1, text2):
    assert unpickled[text2].norm == lex2.norm
    assert unpickled[text1].norm != unpickled[text2].norm
    assert unpickled.vectors is not None
    assert unpickled.get_noun_chunks is not None
    assert list(vocab["dog"].vector) == [1.0, 1.0, 1.0, 1.0, 1.0]
def test_pickle_doc(en_vocab):
words = ["a", "b", "c"]
deps = ["dep"] * len(words)
heads = [0] * len(words)
doc = Doc(
en_vocab,
words=words,
deps=deps,
heads=heads,
)
data = srsly.pickle_dumps(doc)
unpickled = srsly.pickle_loads(data)
assert [t.text for t in unpickled] == words
assert [t.dep_ for t in unpickled] == deps
assert [t.head.i for t in unpickled] == heads
assert list(doc.noun_chunks) == []

View File

@@ -551,12 +551,13 @@ def pickle_vocab(vocab):
    data_dir = vocab.data_dir
    lex_attr_getters = srsly.pickle_dumps(vocab.lex_attr_getters)
    lookups = vocab.lookups
    get_noun_chunks = vocab.get_noun_chunks
    return (unpickle_vocab,
        (sstore, vectors, morph, data_dir, lex_attr_getters, lookups, get_noun_chunks))


def unpickle_vocab(sstore, vectors, morphology, data_dir,
                   lex_attr_getters, lookups, get_noun_chunks):
    cdef Vocab vocab = Vocab()
    vocab.vectors = vectors
    vocab.strings = sstore
@@ -564,6 +565,7 @@ def unpickle_vocab(sstore, vectors, morphology, data_dir,
    vocab.data_dir = data_dir
    vocab.lex_attr_getters = srsly.pickle_loads(lex_attr_getters)
    vocab.lookups = lookups
    vocab.get_noun_chunks = get_noun_chunks
    return vocab