mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 16:22:29 +03:00
Add test for merging pipelines
This commit is contained in:
parent
10bbb01bb6
commit
2791f0b552
|
@ -20,7 +20,7 @@ from spacy.cli._util import parse_config_overrides, string_to_list
|
||||||
from spacy.cli._util import substitute_project_variables
|
from spacy.cli._util import substitute_project_variables
|
||||||
from spacy.cli._util import validate_project_commands
|
from spacy.cli._util import validate_project_commands
|
||||||
from spacy.cli._util import upload_file, download_file
|
from spacy.cli._util import upload_file, download_file
|
||||||
from spacy.cli.configure import configure_resume_cli, use_tok2vec
|
from spacy.cli.configure import configure_resume_cli, use_tok2vec, merge_pipelines
|
||||||
from spacy.cli.debug_data import _compile_gold, _get_labels_from_model
|
from spacy.cli.debug_data import _compile_gold, _get_labels_from_model
|
||||||
from spacy.cli.debug_data import _get_labels_from_spancat
|
from spacy.cli.debug_data import _get_labels_from_spancat
|
||||||
from spacy.cli.debug_data import _get_distribution, _get_kl_divergence
|
from spacy.cli.debug_data import _get_distribution, _get_kl_divergence
|
||||||
|
@ -1216,3 +1216,38 @@ def test_use_tok2vec(tmp_path):
|
||||||
assert out_path.exists(), "No model saved"
|
assert out_path.exists(), "No model saved"
|
||||||
|
|
||||||
assert "tok2vec" in conf["components"], "No tok2vec component"
|
assert "tok2vec" in conf["components"], "No tok2vec component"
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_pipelines(tmp_path):
|
||||||
|
|
||||||
|
# width is a placeholder, since we won't actually train this
|
||||||
|
listener_config = {
|
||||||
|
"model": {
|
||||||
|
"tok2vec": {"@architectures": "spacy.Tok2VecListener.v1", "width": "0"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
# base pipeline
|
||||||
|
base = spacy.blank("en")
|
||||||
|
base.add_pipe("tok2vec")
|
||||||
|
base.add_pipe("ner", config=listener_config)
|
||||||
|
base_path = tmp_path / "merge_base"
|
||||||
|
base.to_disk(base_path)
|
||||||
|
|
||||||
|
# added pipeline
|
||||||
|
added = spacy.blank("en")
|
||||||
|
added.add_pipe("tok2vec")
|
||||||
|
added.add_pipe("ner", config=listener_config)
|
||||||
|
added_path = tmp_path / "merge_added"
|
||||||
|
added.to_disk(added_path)
|
||||||
|
|
||||||
|
# these should combine and not have a name collision
|
||||||
|
out_path = tmp_path / "merge_result"
|
||||||
|
merged = merge_pipelines(base_path, added_path, out_path)
|
||||||
|
|
||||||
|
# will give a key error if not present
|
||||||
|
merged.get_pipe("ner")
|
||||||
|
merged.get_pipe("ner2")
|
||||||
|
|
||||||
|
ner2_conf = merged.config["components"]["ner2"]
|
||||||
|
arch = ner2_conf["model"]["tok2vec"]["@architectures"]
|
||||||
|
assert arch == "spacy.HashEmbedCNN.v2", "Wrong arch - listener not replaced?"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user