Add test for merging pipelines

This commit is contained in:
Paul O'Leary McCann 2022-12-27 19:16:59 +09:00
parent 10bbb01bb6
commit 2791f0b552

View File

@@ -20,7 +20,7 @@ from spacy.cli._util import parse_config_overrides, string_to_list
from spacy.cli._util import substitute_project_variables
from spacy.cli._util import validate_project_commands
from spacy.cli._util import upload_file, download_file
from spacy.cli.configure import configure_resume_cli, use_tok2vec
from spacy.cli.configure import configure_resume_cli, use_tok2vec, merge_pipelines
from spacy.cli.debug_data import _compile_gold, _get_labels_from_model
from spacy.cli.debug_data import _get_labels_from_spancat
from spacy.cli.debug_data import _get_distribution, _get_kl_divergence
@@ -1216,3 +1216,38 @@ def test_use_tok2vec(tmp_path):
assert out_path.exists(), "No model saved"
assert "tok2vec" in conf["components"], "No tok2vec component"
def test_merge_pipelines(tmp_path):
    """Check that two saved pipelines with the same pipe names can be merged.

    Two blank English pipelines, each holding a ``tok2vec`` pipe and an
    ``ner`` pipe configured with a listener, are written to disk and handed
    to ``merge_pipelines``. The merged pipeline should expose both ``ner``
    and ``ner2`` (no name collision), and the ``ner2`` model config should
    report ``spacy.HashEmbedCNN.v2`` as its tok2vec architecture instead of
    the listener.
    """
    # The width is only a placeholder, because nothing here gets trained.
    listener = {"@architectures": "spacy.Tok2VecListener.v1", "width": "0"}
    ner_config = {"model": {"tok2vec": listener}}

    # Build and save the base pipeline first, then the pipeline to be added;
    # both are assembled the same way.
    saved_paths = {}
    for dirname in ("merge_base", "merge_added"):
        nlp = spacy.blank("en")
        nlp.add_pipe("tok2vec")
        nlp.add_pipe("ner", config=ner_config)
        pipeline_dir = tmp_path / dirname
        nlp.to_disk(pipeline_dir)
        saved_paths[dirname] = pipeline_dir

    # These should combine without a name collision.
    result = merge_pipelines(
        saved_paths["merge_base"],
        saved_paths["merge_added"],
        tmp_path / "merge_result",
    )

    # Looking up a missing pipe gives a key error, so these double as checks.
    for pipe_name in ("ner", "ner2"):
        result.get_pipe(pipe_name)

    ner2_model = result.config["components"]["ner2"]["model"]
    found_arch = ner2_model["tok2vec"]["@architectures"]
    assert (
        found_arch == "spacy.HashEmbedCNN.v2"
    ), "Wrong arch - listener not replaced?"