Add use_transformer test

This commit is contained in:
Paul O'Leary McCann 2022-12-27 17:34:25 +09:00
parent f3a928cb4b
commit 836fd87b1e

View File

@ -20,7 +20,7 @@ from spacy.cli._util import parse_config_overrides, string_to_list
from spacy.cli._util import substitute_project_variables from spacy.cli._util import substitute_project_variables
from spacy.cli._util import validate_project_commands from spacy.cli._util import validate_project_commands
from spacy.cli._util import upload_file, download_file from spacy.cli._util import upload_file, download_file
from spacy.cli.configure import configure_resume_cli from spacy.cli.configure import configure_resume_cli, use_transformer
from spacy.cli.debug_data import _compile_gold, _get_labels_from_model from spacy.cli.debug_data import _compile_gold, _get_labels_from_model
from spacy.cli.debug_data import _get_labels_from_spancat from spacy.cli.debug_data import _get_labels_from_spancat
from spacy.cli.debug_data import _get_distribution, _get_kl_divergence from spacy.cli.debug_data import _get_distribution, _get_kl_divergence
@ -1202,3 +1202,16 @@ def test_configure_resume(tmp_path):
for comp, val in conf["components"].items(): for comp, val in conf["components"].items():
assert "source" in val, f"Non-sourced component: {comp}" assert "source" in val, f"Non-sourced component: {comp}"
def test_use_transformer(tmp_path):
nlp = spacy.blank("en")
nlp.add_pipe("tok2vec")
base_path = tmp_path / "tok2vec_sample"
nlp.to_disk(base_path)
out_path = tmp_path / "converted_to_trf"
conf = use_transformer(base_path, out_path)
assert out_path.exists(), "No model saved"
assert "transformer" in conf["components"], "No transformer component"