Fix Spancat test.

This commit is contained in:
Raphael Mitsch 2022-09-01 16:01:53 +02:00
parent 3a0a3854f7
commit 51863cd267

View File

@ -865,17 +865,17 @@ def test_cli_find_threshold(capsys):
for t in [
(
"I'm angry and confused in the Bank of America.",
"I am angry and confused in the Bank of America.",
{
"cats": {"ANGRY": 1.0, "CONFUSED": 1.0, "HAPPY": 0.0},
"spans": {"sc": [(7, 10, "ORG")]},
"spans": {"sc": [(31, 46, "ORG")]},
},
),
(
"I'm confused but happy in New York.",
"I am confused but happy in New York.",
{
"cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0},
"spans": {"sc": [(6, 7, "GPE")]},
"spans": {"sc": [(27, 35, "GPE")]},
},
),
]:
@ -903,12 +903,9 @@ def test_cli_find_threshold(capsys):
if isinstance(comp, TextCategorizer):
for label in textcat_labels:
comp.add_label(label)
if isinstance(comp, SpanCategorizer):
comp.add_label("GPE")
comp.add_label("ORG")
_nlp.initialize()
_examples = make_examples(_nlp)
_nlp.initialize(get_examples=lambda: _examples)
for i in range(5):
_nlp.update(_examples)
@ -936,9 +933,8 @@ def test_cli_find_threshold(capsys):
== numpy.linspace(0, 1, 10)[1]
)
# todo fix spancat test
# Specifying name of non-MultiLabel_TextCategorizer component should fail.
nlp, _ = init_nlp((("spancat", {"spans_key": "sc", "threshold": 0.5}),))
nlp, _ = init_nlp((("spancat", {}),))
with make_tempdir() as nlp_dir:
nlp.to_disk(nlp_dir)
assert (