Add test for multiple textcat components with separate scorers

This commit is contained in:
Adriane Boyd 2022-12-13 14:53:29 +01:00
parent ae919e9907
commit 27a784de8a

View File

@ -3,6 +3,7 @@ import logging
from unittest import mock from unittest import mock
import pytest import pytest
from spacy.language import Language from spacy.language import Language
from spacy.scorer import Scorer
from spacy.tokens import Doc, Span from spacy.tokens import Doc, Span
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.training import Example from spacy.training import Example
@ -180,6 +181,61 @@ def test_evaluate_multiple_textcat_final(en_vocab):
assert scores["cats_f_per_type"].get(key) is None assert scores["cats_f_per_type"].get(key) is None
def test_evaluate_multiple_textcat_separate(en_vocab):
"""Test that evaluate can evaluate multiple textcat components separately
with custom scorers."""
def custom_textcat_score(examples, **kwargs):
def custom_getter(doc, _):
return doc.cats
return Scorer.score_cats(
examples,
"custom_cats",
multi_label=False,
getter=custom_getter,
**kwargs,
)
@spacy.registry.scorers("test_custom_textcat_scorer")
def make_custom_textcat_scorer():
return custom_textcat_score
nlp = Language(en_vocab)
textcat = nlp.add_pipe(
"textcat",
config={"scorer": {"@scorers": "test_custom_textcat_scorer"}},
)
for label in ("POSITIVE", "NEGATIVE"):
textcat.add_label(label)
textcat_multilabel = nlp.add_pipe("textcat_multilabel")
for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
textcat_multilabel.add_label(label)
nlp.initialize()
annots = {
"cats": {
"POSITIVE": 1.0,
"NEGATIVE": 0.0,
"FEATURE": 1.0,
"QUESTION": 1.0,
"POSITIVE": 1.0,
"NEGATIVE": 0.0,
}
}
doc = nlp.make_doc("hello world")
example = Example.from_dict(doc, annots)
scores = nlp.evaluate([example])
# check custom scores for the textcat pipe
assert "custom_cats_f_per_type" in scores
labels = nlp.get_pipe("textcat").labels
assert set(scores["custom_cats_f_per_type"].keys()) == set(labels)
# check default scores for the textcat_multilabel pipe
assert "cats_f_per_type" in scores
labels = nlp.get_pipe("textcat_multilabel").labels
assert set(scores["cats_f_per_type"].keys()) == set(labels)
def vector_modification_pipe(doc): def vector_modification_pipe(doc):
doc.vector += 1 doc.vector += 1
return doc return doc