Add unit test assertions

This commit is contained in:
thomashacker 2023-03-27 14:05:38 +02:00
parent 44704a1275
commit 1da14748db

View File

@ -164,24 +164,6 @@ def test_pretraining_default():
assert "PretrainCharacters" in filled["pretraining"]["objective"]["@architectures"]
@pytest.mark.parametrize("objective", CHAR_OBJECTIVES)
def test_pretraining_last_model(objective):
    """Check that a pretraining run writes a ``model_last.bin`` file.

    Runs once per entry of ``CHAR_OBJECTIVES``; each entry is installed as the
    pretraining objective of the config parsed from
    ``pretrain_string_listener`` before pretraining is started.
    """
    base_config = Config().from_str(pretrain_string_listener)
    base_config["pretraining"]["objective"] = objective
    nlp = util.load_model_from_config(base_config, auto_fill=True, validate=False)
    # Combine the default pretraining config with the pipeline's filled config.
    defaults = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH)
    merged = defaults.merge(nlp.config)
    with make_tempdir() as out_dir:
        # The sample corpus is written into the same temp dir used for output.
        merged["paths"]["raw_text"] = write_sample_jsonl(out_dir)
        resolved = merged.interpolate()
        assert resolved["pretraining"]["component"] == "tok2vec"
        pretrain(resolved, out_dir)
        last_checkpoint = Path(out_dir / "model_last.bin")
        assert last_checkpoint.exists()
@pytest.mark.parametrize("objective", CHAR_OBJECTIVES)
def test_pretraining_tok2vec_characters(objective):
"""Test that pretraining works with the character objective"""
@ -199,6 +181,7 @@ def test_pretraining_tok2vec_characters(objective):
pretrain(filled, tmp_dir)
assert Path(tmp_dir / "model0.bin").exists()
assert Path(tmp_dir / "model4.bin").exists()
assert Path(tmp_dir / "model_last.bin").exists()
assert not Path(tmp_dir / "model5.bin").exists()
@ -255,6 +238,7 @@ def test_pretraining_tagger_tok2vec(config):
pretrain(filled, tmp_dir)
assert Path(tmp_dir / "model0.bin").exists()
assert Path(tmp_dir / "model4.bin").exists()
assert Path(tmp_dir / "model_last.bin").exists()
assert not Path(tmp_dir / "model5.bin").exists()