Fix Spancat test.

This commit is contained in:
Raphael Mitsch 2022-09-01 16:01:53 +02:00
parent 3a0a3854f7
commit 51863cd267

View File

@ -865,17 +865,17 @@ def test_cli_find_threshold(capsys):
for t in [ for t in [
( (
"I'm angry and confused in the Bank of America.", "I am angry and confused in the Bank of America.",
{ {
"cats": {"ANGRY": 1.0, "CONFUSED": 1.0, "HAPPY": 0.0}, "cats": {"ANGRY": 1.0, "CONFUSED": 1.0, "HAPPY": 0.0},
"spans": {"sc": [(7, 10, "ORG")]}, "spans": {"sc": [(31, 46, "ORG")]},
}, },
), ),
( (
"I'm confused but happy in New York.", "I am confused but happy in New York.",
{ {
"cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0}, "cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0},
"spans": {"sc": [(6, 7, "GPE")]}, "spans": {"sc": [(27, 35, "GPE")]},
}, },
), ),
]: ]:
@ -903,12 +903,9 @@ def test_cli_find_threshold(capsys):
if isinstance(comp, TextCategorizer): if isinstance(comp, TextCategorizer):
for label in textcat_labels: for label in textcat_labels:
comp.add_label(label) comp.add_label(label)
if isinstance(comp, SpanCategorizer):
comp.add_label("GPE")
comp.add_label("ORG")
_nlp.initialize()
_examples = make_examples(_nlp) _examples = make_examples(_nlp)
_nlp.initialize(get_examples=lambda: _examples)
for i in range(5): for i in range(5):
_nlp.update(_examples) _nlp.update(_examples)
@ -936,9 +933,8 @@ def test_cli_find_threshold(capsys):
== numpy.linspace(0, 1, 10)[1] == numpy.linspace(0, 1, 10)[1]
) )
# todo fix spancat test
# Specifying name of non-MultiLabel_TextCategorizer component should fail. # Specifying name of non-MultiLabel_TextCategorizer component should fail.
nlp, _ = init_nlp((("spancat", {"spans_key": "sc", "threshold": 0.5}),)) nlp, _ = init_nlp((("spancat", {}),))
with make_tempdir() as nlp_dir: with make_tempdir() as nlp_dir:
nlp.to_disk(nlp_dir) nlp.to_disk(nlp_dir)
assert ( assert (