Pull out parametrized test example data into shared variables

This commit is contained in:
Peter Baumgartner 2023-02-23 09:56:07 -05:00
parent 35f22ba211
commit d37b2094f7

View File

@ -46,6 +46,47 @@ def test_benchmark_accuracy_alias():
) )
# Shared annotations for two three-token example sentences. The two sentences
# differ only in their last word/lemma and in their text-category scores.
example_words_1 = ["I", "like", "cats"]
example_words_2 = ["I", "like", "dogs"]
example_lemmas_1 = ["I", "like", "cat"]
example_lemmas_2 = ["I", "like", "dog"]

# Per-token annotations that are identical for both sentences.
example_tags = ["PRP", "VBP", "NNS"]
example_morphs = [
    "Case=Nom|Number=Sing|Person=1|PronType=Prs",
    "Tense=Pres|VerbForm=Fin",
    "Number=Plur",
]
example_deps = ["nsubj", "ROOT", "dobj"]
example_pos = ["PRON", "VERB", "NOUN"]
example_ents = ["O", "O", "I-ANIMAL"]
# One (start, end, label) span covering the last token.
example_spans = [(2, 3, "ANIMAL")]

# Fully annotated example for the "cats" sentence. The shared lists above are
# referenced directly (not copied), so both examples point at the same objects.
TRAIN_EXAMPLE_1 = {
    "words": example_words_1,
    "lemmas": example_lemmas_1,
    "tags": example_tags,
    "morphs": example_morphs,
    "deps": example_deps,
    "heads": [1, 1, 1],
    "pos": example_pos,
    "ents": example_ents,
    "spans": example_spans,
    "cats": {"CAT": 1.0, "DOG": 0.0},
}

# Fully annotated example for the "dogs" sentence; the category scores are
# the mirror image of TRAIN_EXAMPLE_1.
TRAIN_EXAMPLE_2 = {
    "words": example_words_2,
    "lemmas": example_lemmas_2,
    "tags": example_tags,
    "morphs": example_morphs,
    "deps": example_deps,
    "heads": [1, 1, 1],
    "pos": example_pos,
    "ents": example_ents,
    "spans": example_spans,
    "cats": {"CAT": 0.0, "DOG": 1.0},
}
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize( @pytest.mark.parametrize(
"component,examples", "component,examples",
@ -53,69 +94,50 @@ def test_benchmark_accuracy_alias():
( (
"tagger", "tagger",
[ [
dict(words=["I", "like", "cats"], tags=["PRP", "VBP", "NNS"]), TRAIN_EXAMPLE_1,
dict(words=["I", "like", "dogs"], tags=["PRP", "VBP", "NNS"]), TRAIN_EXAMPLE_2,
], ],
), ),
( (
"morphologizer", "morphologizer",
[ [
dict( TRAIN_EXAMPLE_1,
words=["I", "like", "cats"], TRAIN_EXAMPLE_2,
morphs=[
"Case=Nom|Number=Sing|Person=1|PronType=Prs",
"Tense=Pres|VerbForm=Fin",
"Number=Plur",
],
),
dict(
words=["I", "like", "dogs"],
morphs=[
"Case=Nom|Number=Sing|Person=1|PronType=Prs",
"Tense=Pres|VerbForm=Fin",
"Number=Plur",
],
),
], ],
), ),
( (
"trainable_lemmatizer", "trainable_lemmatizer",
[ [
dict(words=["I", "like", "cats"], lemmas=["I", "like", "cat"]), TRAIN_EXAMPLE_1,
dict(words=["I", "like", "dogs"], lemmas=["I", "like", "dog"]), TRAIN_EXAMPLE_2,
], ],
), ),
( (
"parser", "parser",
[ [
dict( TRAIN_EXAMPLE_1,
words=["I", "like", "cats", "."],
deps=["nsubj", "ROOT", "dobj", "punct"],
heads=[1, 1, 1, 1],
pos=["PRON", "VERB", "NOUN", "PUNCT"],
),
] ]
* 30, * 30,
), ),
( (
"ner", "ner",
[ [
dict(words=["I", "like", "cats"], ents=["O", "O", "I-ANIMAL"]), TRAIN_EXAMPLE_1,
dict(words=["I", "like", "dogs"], ents=["O", "O", "I-ANIMAL"]), TRAIN_EXAMPLE_2,
], ],
), ),
( (
"spancat", "spancat",
[ [
dict(words=["I", "like", "cats"], spans=[(2, 3, "ANIMAL")]), TRAIN_EXAMPLE_1,
dict(words=["I", "like", "dogs"], spans=[(2, 3, "ANIMAL")]), TRAIN_EXAMPLE_2,
], ],
), ),
( (
"textcat", "textcat",
[ [
dict(words=["I", "like", "cats"], cats={"CAT": 1.0, "DOG": 0.0}), TRAIN_EXAMPLE_1,
dict(words=["I", "like", "dogs"], cats={"CAT": 0.0, "DOG": 1.0}), TRAIN_EXAMPLE_2,
], ],
), ),
], ],
@ -136,7 +158,12 @@ def test_init_config_trainable(component, examples, en_vocab):
] ]
train_docs.append(doc) train_docs.append(doc)
else: else:
train_docs = [Doc(en_vocab, **example) for example in examples] train_docs = []
for example in examples:
# cats, spans are not valid kwargs for instantiating a Doc
example = {k: v for k, v in example.items() if k not in ("cats", "spans")}
doc = Doc(en_vocab, **example)
train_docs.append(doc)
with make_tempdir() as d_in: with make_tempdir() as d_in:
train_bin = DocBin(docs=train_docs) train_bin = DocBin(docs=train_docs)
@ -173,7 +200,7 @@ def test_init_config_trainable(component, examples, en_vocab):
assert Path(d_in / "model" / "model-last").exists() assert Path(d_in / "model" / "model-last").exists()
# @pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize( @pytest.mark.parametrize(
"component,examples", "component,examples",
[ [
@ -181,28 +208,20 @@ def test_init_config_trainable(component, examples, en_vocab):
"tagger,parser,morphologizer", "tagger,parser,morphologizer",
[ [
dict( dict(
words=["I", "like", "cats"], words=example_words_1,
tags=["PRP", "VBP", "NNS"], tags=example_tags,
morphs=[ morphs=example_morphs,
"Case=Nom|Number=Sing|Person=1|PronType=Prs", deps=example_deps,
"Tense=Pres|VerbForm=Fin",
"Number=Plur",
],
deps=["nsubj", "ROOT", "dobj"],
heads=[1, 1, 1], heads=[1, 1, 1],
pos=["PRON", "VERB", "NOUN"], pos=example_pos,
), ),
dict( dict(
words=["I", "like", "dogs"], words=example_words_2,
tags=["PRP", "VBP", "NNS"], tags=example_tags,
morphs=[ morphs=example_morphs,
"Case=Nom|Number=Sing|Person=1|PronType=Prs", deps=example_deps,
"Tense=Pres|VerbForm=Fin",
"Number=Plur",
],
deps=["nsubj", "ROOT", "dobj"],
heads=[1, 1, 1], heads=[1, 1, 1],
pos=["PRON", "VERB", "NOUN"], pos=example_pos,
), ),
] ]
* 15, * 15,