Update tests

This commit is contained in:
Ines Montani 2021-01-29 19:38:09 +11:00
parent 325f47500d
commit bc089b693c

View File

@ -6,12 +6,13 @@ from spacy.pipeline.tok2vec import Tok2Vec, Tok2VecListener
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.tokens import Doc from spacy.tokens import Doc
from spacy.training import Example from spacy.training import Example
from spacy.training.initialize import init_nlp
from spacy import util from spacy import util
from spacy.lang.en import English from spacy.lang.en import English
from thinc.api import Config from thinc.api import Config
from numpy.testing import assert_equal from numpy.testing import assert_equal
from ..util import get_batch from ..util import get_batch, make_tempdir
def test_empty_doc(): def test_empty_doc():
@ -55,17 +56,17 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
assert doc_vec.shape == (len(doc), width) assert doc_vec.shape == (len(doc), width)
# fmt: off
@pytest.mark.parametrize( @pytest.mark.parametrize(
"width,embed_arch,embed_config,encode_arch,encode_config", "width,embed_arch,embed_config,encode_arch,encode_config",
# fmt: off
[ [
(8, MultiHashEmbed, {"rows": [100, 100], "attrs": ["SHAPE", "LOWER"], "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}), (8, MultiHashEmbed, {"rows": [100, 100], "attrs": ["SHAPE", "LOWER"], "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}),
(8, MultiHashEmbed, {"rows": [100, 20], "attrs": ["ORTH", "PREFIX"], "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}), (8, MultiHashEmbed, {"rows": [100, 20], "attrs": ["ORTH", "PREFIX"], "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}),
(8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8, "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}), (8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8, "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}),
(8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2, "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 3}), (8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2, "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 3}),
], ],
)
# fmt: on # fmt: on
)
def test_tok2vec_configs(width, embed_arch, embed_config, encode_arch, encode_config): def test_tok2vec_configs(width, embed_arch, embed_config, encode_arch, encode_config):
embed_config["width"] = width embed_config["width"] = width
encode_config["width"] = width encode_config["width"] = width
@ -196,8 +197,14 @@ def test_replace_listeners():
tagger = nlp.get_pipe("tagger") tagger = nlp.get_pipe("tagger")
assert isinstance(tagger.model.layers[0], Tok2VecListener) assert isinstance(tagger.model.layers[0], Tok2VecListener)
assert tok2vec.listener_map["tagger"][0] == tagger.model.layers[0] assert tok2vec.listener_map["tagger"][0] == tagger.model.layers[0]
assert nlp.config["components"]["tok2vec"]["model"]["@architectures"] == "spacy.Tok2Vec.v2" assert (
assert nlp.config["components"]["tagger"]["model"]["tok2vec"]["@architectures"] == "spacy.Tok2VecListener.v1" nlp.config["components"]["tok2vec"]["model"]["@architectures"]
== "spacy.Tok2Vec.v2"
)
assert (
nlp.config["components"]["tagger"]["model"]["tok2vec"]["@architectures"]
== "spacy.Tok2VecListener.v1"
)
nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"]) nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"])
assert not isinstance(tagger.model.layers[0], Tok2VecListener) assert not isinstance(tagger.model.layers[0], Tok2VecListener)
t2v_cfg = nlp.config["components"]["tok2vec"]["model"] t2v_cfg = nlp.config["components"]["tok2vec"]["model"]
@ -211,3 +218,96 @@ def test_replace_listeners():
nlp.replace_listeners("tok2vec", "tagger", ["model.yolo"]) nlp.replace_listeners("tok2vec", "tagger", ["model.yolo"])
with pytest.raises(ValueError): with pytest.raises(ValueError):
nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec", "model.yolo"]) nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec", "model.yolo"])
# Config for a pipeline of ["tok2vec", "tagger", "ner"] in which the tagger
# and ner models each use a "spacy.Tok2VecListener.v1" layer as their
# `tok2vec` sublayer. Every `width` setting other than the encoder's own is
# interpolated from components.tok2vec.model.encode.width (96).
# Parsed by test_replace_listeners_from_config via Config().from_str().
cfg_string_multi = """
[nlp]
lang = "en"
pipeline = ["tok2vec","tagger", "ner"]
[components]
[components.tagger]
factory = "tagger"
[components.tagger.model]
@architectures = "spacy.Tagger.v1"
nO = null
[components.tagger.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
width = ${components.tok2vec.model.encode.width}
[components.ner]
factory = "ner"
[components.ner.model]
@architectures = "spacy.TransitionBasedParser.v2"
[components.ner.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
width = ${components.tok2vec.model.encode.width}
[components.tok2vec]
factory = "tok2vec"
[components.tok2vec.model]
@architectures = "spacy.Tok2Vec.v2"
[components.tok2vec.model.embed]
@architectures = "spacy.MultiHashEmbed.v1"
width = ${components.tok2vec.model.encode.width}
rows = [2000, 1000, 1000, 1000]
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
include_static_vectors = false
[components.tok2vec.model.encode]
@architectures = "spacy.MaxoutWindowEncoder.v2"
width = 96
depth = 4
window_size = 1
maxout_pieces = 3
"""
def test_replace_listeners_from_config():
    """Check `replace_listeners` when it is set in a sourced component's config.

    A pipeline built from `cfg_string_multi` is initialized and saved to a
    temporary directory. A second pipeline then sources all three components
    from that directory, with `replace_listeners` set on the tagger only.
    The test asserts that afterwards only the ner is still listening to the
    tok2vec component, and that the tagger's `tok2vec` config block equals
    the tok2vec component's model config.
    """

    def _has_listener(pipe):
        # True if any node in the component's model tree is a Tok2VecListener.
        return any(isinstance(layer, Tok2VecListener) for layer in pipe.model.walk())

    # Build and initialize the base pipeline with a single two-token example.
    base_nlp = util.load_model_from_config(
        Config().from_str(cfg_string_multi), auto_fill=True
    )
    gold = {"tags": ["V", "Z"], "entities": [(0, 1, "A"), (1, 2, "B")]}
    examples = [Example.from_dict(base_nlp.make_doc("x y"), gold)]
    base_nlp.initialize(lambda: examples)

    # Before saving, both downstream components listen to the tok2vec.
    base_t2v = base_nlp.get_pipe("tok2vec")
    base_tagger = base_nlp.get_pipe("tagger")
    base_ner = base_nlp.get_pipe("ner")
    assert base_t2v.listening_components == ["tagger", "ner"]
    assert _has_listener(base_ner)
    assert _has_listener(base_tagger)

    with make_tempdir() as dir_path:
        base_nlp.to_disk(dir_path)
        model_path = str(dir_path)
        # Source every component from disk; only the tagger asks for its
        # listener to be replaced.
        sourced_components = {
            "tok2vec": {"source": model_path},
            "tagger": {
                "source": model_path,
                "replace_listeners": ["model.tok2vec"],
            },
            "ner": {"source": model_path},
        }
        sourced_config = {
            "nlp": {"lang": "en", "pipeline": ["tok2vec", "tagger", "ner"]},
            "components": sourced_components,
        }
        new_nlp = util.load_model_from_config(sourced_config, auto_fill=True)
        # The remaining steps run while the temporary directory still exists.
        new_nlp.initialize(lambda: examples)

        new_t2v = new_nlp.get_pipe("tok2vec")
        new_tagger = new_nlp.get_pipe("tagger")
        new_ner = new_nlp.get_pipe("ner")
        assert new_t2v.listening_components == ["ner"]
        assert _has_listener(new_ner)
        assert not _has_listener(new_tagger)

        # The tagger's tok2vec block should now equal the tok2vec component's
        # model config, while the ner keeps its listener architecture.
        t2v_cfg = new_nlp.config["components"]["tok2vec"]["model"]
        assert t2v_cfg["@architectures"] == "spacy.Tok2Vec.v2"
        tagger_t2v_cfg = new_nlp.config["components"]["tagger"]["model"]["tok2vec"]
        assert tagger_t2v_cfg == t2v_cfg
        ner_t2v_cfg = new_nlp.config["components"]["ner"]["model"]["tok2vec"]
        assert ner_t2v_cfg["@architectures"] == "spacy.Tok2VecListener.v1"