mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-06 05:10:21 +03:00
Restrict cats_score to provided labels
This commit is contained in:
parent
6f9d630f7e
commit
f7c6cf80be
|
@@ -473,17 +473,17 @@ class Scorer:
|
||||||
threshold = 0.5 if multi_label else 0.0
|
threshold = 0.5 if multi_label else 0.0
|
||||||
if not multi_label:
|
if not multi_label:
|
||||||
threshold = 0.0
|
threshold = 0.0
|
||||||
|
labels = set(labels)
|
||||||
f_per_type = {label: PRFScore() for label in labels}
|
f_per_type = {label: PRFScore() for label in labels}
|
||||||
auc_per_type = {label: ROCAUCScore() for label in labels}
|
auc_per_type = {label: ROCAUCScore() for label in labels}
|
||||||
labels = set(labels)
|
|
||||||
if labels:
|
|
||||||
for eg in examples:
|
|
||||||
labels.update(eg.predicted.cats.keys())
|
|
||||||
labels.update(eg.reference.cats.keys())
|
|
||||||
for example in examples:
|
for example in examples:
|
||||||
# Through this loop, None in the gold_cats indicates missing label.
|
# Through this loop, None in the gold_cats indicates missing label.
|
||||||
pred_cats = getter(example.predicted, attr)
|
pred_cats = {
|
||||||
gold_cats = getter(example.reference, attr)
|
k: v for k, v in getter(example.predicted, attr).items() if k in labels
|
||||||
|
}
|
||||||
|
gold_cats = {
|
||||||
|
k: v for k, v in getter(example.reference, attr).items() if k in labels
|
||||||
|
}
|
||||||
|
|
||||||
for label in labels:
|
for label in labels:
|
||||||
pred_score = pred_cats.get(label, 0.0)
|
pred_score = pred_cats.get(label, 0.0)
|
||||||
|
|
|
@@ -58,6 +58,29 @@ def nlp():
|
||||||
return nlp
|
return nlp
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def nlp_tcm():
|
||||||
|
nlp = Language(Vocab())
|
||||||
|
textcat_multilabel = nlp.add_pipe("textcat_multilabel")
|
||||||
|
for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
|
||||||
|
textcat_multilabel.add_label(label)
|
||||||
|
nlp.initialize()
|
||||||
|
return nlp
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def nlp_tc_tcm():
|
||||||
|
nlp = Language(Vocab())
|
||||||
|
textcat = nlp.add_pipe("textcat")
|
||||||
|
for label in ("POSITIVE", "NEGATIVE"):
|
||||||
|
textcat.add_label(label)
|
||||||
|
textcat_multilabel = nlp.add_pipe("textcat_multilabel")
|
||||||
|
for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
|
||||||
|
textcat_multilabel.add_label(label)
|
||||||
|
nlp.initialize()
|
||||||
|
return nlp
|
||||||
|
|
||||||
|
|
||||||
def test_language_update(nlp):
|
def test_language_update(nlp):
|
||||||
text = "hello world"
|
text = "hello world"
|
||||||
annots = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}
|
annots = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}
|
||||||
|
@@ -126,6 +149,42 @@ def test_evaluate_no_pipe(nlp):
|
||||||
nlp.evaluate([Example.from_dict(doc, annots)])
|
nlp.evaluate([Example.from_dict(doc, annots)])
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_textcat_multilabel(nlp_tcm):
|
||||||
|
"""Test that evaluate works with a multilabel textcat pipe."""
|
||||||
|
text = "hello world"
|
||||||
|
annots = {"doc_annotation": {"cats": {"FEATURE": 1.0, "QUESTION": 1.0}}}
|
||||||
|
doc = Doc(nlp_tcm.vocab, words=text.split(" "))
|
||||||
|
example = Example.from_dict(doc, annots)
|
||||||
|
scores = nlp_tcm.evaluate([example])
|
||||||
|
labels = nlp_tcm.get_pipe("textcat_multilabel").labels
|
||||||
|
for label in labels:
|
||||||
|
assert scores["cats_f_per_type"].get(label) is not None
|
||||||
|
for key in example.reference.cats.keys():
|
||||||
|
if key not in labels:
|
||||||
|
assert scores["cats_f_per_type"].get(key) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_multiple_textcat(nlp_tc_tcm):
|
||||||
|
"""Test that evaluate evaluates the final textcat component in a pipeline
|
||||||
|
with more than one textcat or textcat_multilabel."""
|
||||||
|
text = "hello world"
|
||||||
|
annots = {
|
||||||
|
"doc_annotation": {
|
||||||
|
"cats": {"FEATURE": 1.0, "QUESTION": 1.0, "POSITIVE": 1.0, "NEGATIVE": 0.0}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
doc = Doc(nlp_tc_tcm.vocab, words=text.split(" "))
|
||||||
|
example = Example.from_dict(doc, annots)
|
||||||
|
scores = nlp_tc_tcm.evaluate([example])
|
||||||
|
# get the labels from the final pipe
|
||||||
|
labels = nlp_tc_tcm.get_pipe(nlp_tc_tcm.pipe_names[-1]).labels
|
||||||
|
for label in labels:
|
||||||
|
assert scores["cats_f_per_type"].get(label) is not None
|
||||||
|
for key in example.reference.cats.keys():
|
||||||
|
if key not in labels:
|
||||||
|
assert scores["cats_f_per_type"].get(key) is None
|
||||||
|
|
||||||
|
|
||||||
def vector_modification_pipe(doc):
|
def vector_modification_pipe(doc):
|
||||||
doc.vector += 1
|
doc.vector += 1
|
||||||
return doc
|
return doc
|
||||||
|
|
Loading…
Reference in New Issue
Block a user