From 2791f0b552f2c684206310d2d66ba5cf80e55a36 Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Tue, 27 Dec 2022 19:16:59 +0900 Subject: [PATCH] Add test for merging pipelines --- spacy/tests/test_cli.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index 627c59b1e..7faab850a 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -20,7 +20,7 @@ from spacy.cli._util import parse_config_overrides, string_to_list from spacy.cli._util import substitute_project_variables from spacy.cli._util import validate_project_commands from spacy.cli._util import upload_file, download_file -from spacy.cli.configure import configure_resume_cli, use_tok2vec +from spacy.cli.configure import configure_resume_cli, use_tok2vec, merge_pipelines from spacy.cli.debug_data import _compile_gold, _get_labels_from_model from spacy.cli.debug_data import _get_labels_from_spancat from spacy.cli.debug_data import _get_distribution, _get_kl_divergence @@ -1216,3 +1216,38 @@ def test_use_tok2vec(tmp_path): assert out_path.exists(), "No model saved" assert "tok2vec" in conf["components"], "No tok2vec component" + + +def test_merge_pipelines(tmp_path): + + # width is a placeholder, since we won't actually train this + listener_config = { + "model": { + "tok2vec": {"@architectures": "spacy.Tok2VecListener.v1", "width": "0"} + } + } + # base pipeline + base = spacy.blank("en") + base.add_pipe("tok2vec") + base.add_pipe("ner", config=listener_config) + base_path = tmp_path / "merge_base" + base.to_disk(base_path) + + # added pipeline + added = spacy.blank("en") + added.add_pipe("tok2vec") + added.add_pipe("ner", config=listener_config) + added_path = tmp_path / "merge_added" + added.to_disk(added_path) + + # these should combine and not have a name collision + out_path = tmp_path / "merge_result" + merged = merge_pipelines(base_path, added_path, out_path) + + # will give a key error if not present + merged.get_pipe("ner") + merged.get_pipe("ner2") + + ner2_conf = merged.config["components"]["ner2"] + arch = ner2_conf["model"]["tok2vec"]["@architectures"] + assert arch == "spacy.HashEmbedCNN.v2", "Wrong arch - listener not replaced?"