Improve score_cats for use with multiple textcat components (#11820)

* add test for running evaluate on an nlp pipeline with two distinct textcat components

* cleanup

* merge dicts instead of overwrite

* don't add more labels to the given set

* Revert "merge dicts instead of overwrite"

This reverts commit 89bee0ed77.

* Switch tests to separate scorer keys rather than merged dicts

* Revert unrelated edits

* Switch textcat scorers to v2

* formatting

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
Author: Sofie Van Landeghem
Date: 2023-01-09 11:43:48 +01:00 (committed by GitHub)
Commit: 6d03b04901 (parent: f1dcdefc8a)
4 changed files with 116 additions and 7 deletions
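
In short: with more than one textcat or textcat_multilabel component in a pipeline, the default "cats_*" score keys now come from the component that runs last, and each component is scored only on its own label set. A minimal sketch of the resulting behavior (not part of the diff; it mirrors the new tests below):

    import spacy
    from spacy.training import Example

    nlp = spacy.blank("en")
    textcat = nlp.add_pipe("textcat")
    for label in ("POSITIVE", "NEGATIVE"):
        textcat.add_label(label)
    textcat_multilabel = nlp.add_pipe("textcat_multilabel")
    for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
        textcat_multilabel.add_label(label)
    nlp.initialize()

    example = Example.from_dict(
        nlp.make_doc("hello world"),
        {"cats": {"POSITIVE": 1.0, "FEATURE": 1.0, "QUESTION": 1.0}},
    )
    scores = nlp.evaluate([example])
    # Each component's scorer writes the same "cats_*" keys in pipeline
    # order, so the final component (textcat_multilabel) wins and the
    # per-type scores cover only its labels.
    assert set(scores["cats_f_per_type"]) == {"FEATURE", "REQUEST", "BUG", "QUESTION"}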

spacy/pipeline/textcat_multilabel.py

@@ -74,7 +74,7 @@ subword_features = true
     default_config={
         "threshold": 0.5,
         "model": DEFAULT_MULTI_TEXTCAT_MODEL,
-        "scorer": {"@scorers": "spacy.textcat_multilabel_scorer.v1"},
+        "scorer": {"@scorers": "spacy.textcat_multilabel_scorer.v2"},
     },
     default_score_weights={
         "cats_score": 1.0,
@@ -120,7 +120,7 @@ def textcat_multilabel_score(examples: Iterable[Example], **kwargs) -> Dict[str,
     )


-@registry.scorers("spacy.textcat_multilabel_scorer.v1")
+@registry.scorers("spacy.textcat_multilabel_scorer.v2")
 def make_textcat_multilabel_scorer():
     return textcat_multilabel_score
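
The default scorer is bumped to v2 above, but the v1 name stays registered (it ships in the spacy-legacy package), so a config can opt back into the old behavior explicitly. A sketch, assuming spacy-legacy is installed:

    import spacy

    nlp = spacy.blank("en")
    # Pin the legacy v1 scorer instead of the new v2 default.
    nlp.add_pipe(
        "textcat_multilabel",
        config={"scorer": {"@scorers": "spacy.textcat_multilabel_scorer.v1"}},
    )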

spacy/scorer.py

@@ -476,14 +476,12 @@ class Scorer:
         f_per_type = {label: PRFScore() for label in labels}
         auc_per_type = {label: ROCAUCScore() for label in labels}
         labels = set(labels)
-        if labels:
-            for eg in examples:
-                labels.update(eg.predicted.cats.keys())
-                labels.update(eg.reference.cats.keys())
         for example in examples:
             # Through this loop, None in the gold_cats indicates missing label.
             pred_cats = getter(example.predicted, attr)
+            pred_cats = {k: v for k, v in pred_cats.items() if k in labels}
             gold_cats = getter(example.reference, attr)
+            gold_cats = {k: v for k, v in gold_cats.items() if k in labels}
             for label in labels:
                 pred_score = pred_cats.get(label, 0.0)
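
The net effect: score_cats no longer extends a given label set with categories found on the examples; predicted and gold cats outside that set are simply filtered out. A minimal sketch of the new behavior ("EXTRA" is an illustrative out-of-set category; the doc is untrained, so predictions default to 0.0):

    from spacy.lang.en import English
    from spacy.scorer import Scorer
    from spacy.training import Example

    nlp = English()
    doc = nlp.make_doc("hello world")
    # The reference carries a category outside the given label set.
    example = Example.from_dict(doc, {"cats": {"POSITIVE": 1.0, "EXTRA": 1.0}})

    scores = Scorer.score_cats(
        [example], "cats", labels=["POSITIVE", "NEGATIVE"], multi_label=True
    )
    # Per-type scores cover exactly the given labels; "EXTRA" is ignored
    # rather than being merged into the set.
    assert set(scores["cats_f_per_type"]) == {"POSITIVE", "NEGATIVE"}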

spacy/tests/pipeline/test_textcat.py

@@ -898,7 +898,11 @@ def test_textcat_multi_threshold():
 @pytest.mark.parametrize(
-    "component_name,scorer", [("textcat", "spacy.textcat_scorer.v1")]
+    "component_name,scorer",
+    [
+        ("textcat", "spacy.textcat_scorer.v1"),
+        ("textcat_multilabel", "spacy.textcat_multilabel_scorer.v1"),
+    ],
 )
 def test_textcat_legacy_scorers(component_name, scorer):
     """Check that legacy scorers are registered and produce the expected score

spacy/tests/test_language.py

@@ -3,6 +3,7 @@ import logging
 from unittest import mock
 import pytest
 from spacy.language import Language
+from spacy.scorer import Scorer
 from spacy.tokens import Doc, Span
 from spacy.vocab import Vocab
 from spacy.training import Example
@@ -126,6 +127,112 @@ def test_evaluate_no_pipe(nlp):
     nlp.evaluate([Example.from_dict(doc, annots)])


+def test_evaluate_textcat_multilabel(en_vocab):
+    """Test that evaluate works with a multilabel textcat pipe."""
+    nlp = Language(en_vocab)
+    textcat_multilabel = nlp.add_pipe("textcat_multilabel")
+    for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
+        textcat_multilabel.add_label(label)
+    nlp.initialize()
+
+    annots = {"cats": {"FEATURE": 1.0, "QUESTION": 1.0}}
+    doc = nlp.make_doc("hello world")
+    example = Example.from_dict(doc, annots)
+    scores = nlp.evaluate([example])
+    labels = nlp.get_pipe("textcat_multilabel").labels
+    for label in labels:
+        assert scores["cats_f_per_type"].get(label) is not None
+    for key in example.reference.cats.keys():
+        if key not in labels:
+            assert scores["cats_f_per_type"].get(key) is None
+
+
+def test_evaluate_multiple_textcat_final(en_vocab):
+    """Test that evaluate evaluates the final textcat component in a pipeline
+    with more than one textcat or textcat_multilabel."""
+    nlp = Language(en_vocab)
+    textcat = nlp.add_pipe("textcat")
+    for label in ("POSITIVE", "NEGATIVE"):
+        textcat.add_label(label)
+    textcat_multilabel = nlp.add_pipe("textcat_multilabel")
+    for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
+        textcat_multilabel.add_label(label)
+    nlp.initialize()
+
+    annots = {
+        "cats": {
+            "POSITIVE": 1.0,
+            "NEGATIVE": 0.0,
+            "FEATURE": 1.0,
+            "QUESTION": 1.0,
+            "POSITIVE": 1.0,
+            "NEGATIVE": 0.0,
+        }
+    }
+    doc = nlp.make_doc("hello world")
+    example = Example.from_dict(doc, annots)
+    scores = nlp.evaluate([example])
+    # get the labels from the final pipe
+    labels = nlp.get_pipe(nlp.pipe_names[-1]).labels
+    for label in labels:
+        assert scores["cats_f_per_type"].get(label) is not None
+    for key in example.reference.cats.keys():
+        if key not in labels:
+            assert scores["cats_f_per_type"].get(key) is None
+
+
+def test_evaluate_multiple_textcat_separate(en_vocab):
+    """Test that evaluate can evaluate multiple textcat components separately
+    with custom scorers."""
+
+    def custom_textcat_score(examples, **kwargs):
+        scores = Scorer.score_cats(
+            examples,
+            "cats",
+            multi_label=False,
+            **kwargs,
+        )
+        return {f"custom_{k}": v for k, v in scores.items()}
+
+    @spacy.registry.scorers("test_custom_textcat_scorer")
+    def make_custom_textcat_scorer():
+        return custom_textcat_score
+
+    nlp = Language(en_vocab)
+    textcat = nlp.add_pipe(
+        "textcat",
+        config={"scorer": {"@scorers": "test_custom_textcat_scorer"}},
+    )
+    for label in ("POSITIVE", "NEGATIVE"):
+        textcat.add_label(label)
+    textcat_multilabel = nlp.add_pipe("textcat_multilabel")
+    for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
+        textcat_multilabel.add_label(label)
+    nlp.initialize()
+
+    annots = {
+        "cats": {
+            "POSITIVE": 1.0,
+            "NEGATIVE": 0.0,
+            "FEATURE": 1.0,
+            "QUESTION": 1.0,
+            "POSITIVE": 1.0,
+            "NEGATIVE": 0.0,
+        }
+    }
+    doc = nlp.make_doc("hello world")
+    example = Example.from_dict(doc, annots)
+    scores = nlp.evaluate([example])
+    # check custom scores for the textcat pipe
+    assert "custom_cats_f_per_type" in scores
+    labels = nlp.get_pipe("textcat").labels
+    assert set(scores["custom_cats_f_per_type"].keys()) == set(labels)
+    # check default scores for the textcat_multilabel pipe
+    assert "cats_f_per_type" in scores
+    labels = nlp.get_pipe("textcat_multilabel").labels
+    assert set(scores["cats_f_per_type"].keys()) == set(labels)
+
+
 def vector_modification_pipe(doc):
     doc.vector += 1
     return doc