mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-07 13:44:55 +03:00
Add test for multiple textcat components with separate scorers
This commit is contained in:
parent
ae919e9907
commit
27a784de8a
|
@ -3,6 +3,7 @@ import logging
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
import pytest
|
import pytest
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
|
from spacy.scorer import Scorer
|
||||||
from spacy.tokens import Doc, Span
|
from spacy.tokens import Doc, Span
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
from spacy.training import Example
|
from spacy.training import Example
|
||||||
|
@ -180,6 +181,61 @@ def test_evaluate_multiple_textcat_final(en_vocab):
|
||||||
assert scores["cats_f_per_type"].get(key) is None
|
assert scores["cats_f_per_type"].get(key) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_multiple_textcat_separate(en_vocab):
    """Test that evaluate can evaluate multiple textcat components separately
    with custom scorers.

    The "textcat" pipe is configured with a registered custom scorer that
    reports its results under the "custom_cats" prefix, while the
    "textcat_multilabel" pipe is added without a scorer override. The
    assertions check that each pipe's per-type scores cover exactly that
    pipe's own labels.
    """

    def custom_textcat_score(examples, **kwargs):
        # The getter ignores the requested attribute name and always reads
        # doc.cats, so the "custom_cats" name below only affects the keys of
        # the returned scores.
        def custom_getter(doc, _):
            return doc.cats

        return Scorer.score_cats(
            examples,
            "custom_cats",
            multi_label=False,
            getter=custom_getter,
            **kwargs,
        )

    # Register the scorer factory so the pipe config can refer to it by name.
    @spacy.registry.scorers("test_custom_textcat_scorer")
    def make_custom_textcat_scorer():
        return custom_textcat_score

    nlp = Language(en_vocab)
    textcat = nlp.add_pipe(
        "textcat",
        config={"scorer": {"@scorers": "test_custom_textcat_scorer"}},
    )
    for label in ("POSITIVE", "NEGATIVE"):
        textcat.add_label(label)
    textcat_multilabel = nlp.add_pipe("textcat_multilabel")
    for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
        textcat_multilabel.add_label(label)
    nlp.initialize()

    # Each key appears once: a repeated key in a dict literal is silently
    # overwritten by the later entry, so duplicates add nothing.
    annots = {
        "cats": {
            "POSITIVE": 1.0,
            "NEGATIVE": 0.0,
            "FEATURE": 1.0,
            "QUESTION": 1.0,
        }
    }
    doc = nlp.make_doc("hello world")
    example = Example.from_dict(doc, annots)
    scores = nlp.evaluate([example])
    # check custom scores for the textcat pipe
    assert "custom_cats_f_per_type" in scores
    labels = nlp.get_pipe("textcat").labels
    assert set(scores["custom_cats_f_per_type"].keys()) == set(labels)
    # check default scores for the textcat_multilabel pipe
    assert "cats_f_per_type" in scores
    labels = nlp.get_pipe("textcat_multilabel").labels
    assert set(scores["cats_f_per_type"].keys()) == set(labels)
|
||||||
|
|
||||||
|
|
||||||
def vector_modification_pipe(doc):
|
def vector_modification_pipe(doc):
|
||||||
doc.vector += 1
|
doc.vector += 1
|
||||||
return doc
|
return doc
|
||||||
|
|
Loading…
Reference in New Issue
Block a user