mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
Extend score_spans for overlapping & non-labeled spans (#7209)
* extend span scorer with consider_label and allow_overlap * unit test for spans y2x overlap * add score_spans unit test * docs for new fields in scorer.score_spans * rename to include_label * spell out if-else for clarity * rename to 'labeled' Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
c362006cb9
commit
204c2f116b
|
@ -311,6 +311,8 @@ class Scorer:
|
|||
*,
|
||||
getter: Callable[[Doc, str], Iterable[Span]] = getattr,
|
||||
has_annotation: Optional[Callable[[Doc], bool]] = None,
|
||||
labeled: bool = True,
|
||||
allow_overlap: bool = False,
|
||||
**cfg,
|
||||
) -> Dict[str, Any]:
|
||||
"""Returns PRF scores for labeled spans.
|
||||
|
@ -323,6 +325,11 @@ class Scorer:
|
|||
has_annotation (Optional[Callable[[Doc], bool]]) should return whether a `Doc`
|
||||
has annotation for this `attr`. Docs without annotation are skipped for
|
||||
scoring purposes.
|
||||
labeled (bool): Whether or not to include label information in
|
||||
the evaluation. If set to 'False', two spans will be considered
|
||||
equal if their start and end match, irrespective of their label.
|
||||
allow_overlap (bool): Whether or not to allow overlapping spans.
|
||||
If set to 'False', the alignment will automatically resolve conflicts.
|
||||
RETURNS (Dict[str, Any]): A dictionary containing the PRF scores under
|
||||
the keys attr_p/r/f and the per-type PRF scores under attr_per_type.
|
||||
|
||||
|
@ -351,33 +358,42 @@ class Scorer:
|
|||
gold_spans = set()
|
||||
pred_spans = set()
|
||||
for span in getter(gold_doc, attr):
|
||||
if labeled:
|
||||
gold_span = (span.label_, span.start, span.end - 1)
|
||||
else:
|
||||
gold_span = (span.start, span.end - 1)
|
||||
gold_spans.add(gold_span)
|
||||
gold_per_type[span.label_].add((span.label_, span.start, span.end - 1))
|
||||
gold_per_type[span.label_].add(gold_span)
|
||||
pred_per_type = {label: set() for label in labels}
|
||||
for span in example.get_aligned_spans_x2y(getter(pred_doc, attr)):
|
||||
pred_spans.add((span.label_, span.start, span.end - 1))
|
||||
pred_per_type[span.label_].add((span.label_, span.start, span.end - 1))
|
||||
for span in example.get_aligned_spans_x2y(getter(pred_doc, attr), allow_overlap):
|
||||
if labeled:
|
||||
pred_span = (span.label_, span.start, span.end - 1)
|
||||
else:
|
||||
pred_span = (span.start, span.end - 1)
|
||||
pred_spans.add(pred_span)
|
||||
pred_per_type[span.label_].add(pred_span)
|
||||
# Scores per label
|
||||
if labeled:
|
||||
for k, v in score_per_type.items():
|
||||
if k in pred_per_type:
|
||||
v.score_set(pred_per_type[k], gold_per_type[k])
|
||||
# Score for all labels
|
||||
score.score_set(pred_spans, gold_spans)
|
||||
if len(score) > 0:
|
||||
return {
|
||||
f"{attr}_p": score.precision,
|
||||
f"{attr}_r": score.recall,
|
||||
f"{attr}_f": score.fscore,
|
||||
f"{attr}_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
# Assemble final result
|
||||
final_scores = {
|
||||
f"{attr}_p": None,
|
||||
f"{attr}_r": None,
|
||||
f"{attr}_f": None,
|
||||
f"{attr}_per_type": None,
|
||||
}
|
||||
if labeled:
|
||||
final_scores[f"{attr}_per_type"] = None
|
||||
if len(score) > 0:
|
||||
final_scores[f"{attr}_p"] = score.precision
|
||||
final_scores[f"{attr}_r"] = score.recall
|
||||
final_scores[f"{attr}_f"] = score.fscore
|
||||
if labeled:
|
||||
final_scores[f"{attr}_per_type"] = {k: v.to_dict() for k, v in score_per_type.items()}
|
||||
return final_scores
|
||||
|
||||
@staticmethod
|
||||
def score_cats(
|
||||
|
|
|
@ -6,7 +6,7 @@ from spacy.training.iob_utils import offsets_to_biluo_tags
|
|||
from spacy.scorer import Scorer, ROCAUCScore, PRFScore
|
||||
from spacy.scorer import _roc_auc_score, _roc_curve
|
||||
from spacy.lang.en import English
|
||||
from spacy.tokens import Doc
|
||||
from spacy.tokens import Doc, Span
|
||||
|
||||
|
||||
test_las_apple = [
|
||||
|
@ -405,6 +405,51 @@ def test_roc_auc_score():
|
|||
_ = score.score # noqa: F841
|
||||
|
||||
|
||||
def test_score_spans():
|
||||
nlp = English()
|
||||
text = "This is just a random sentence."
|
||||
key = "my_spans"
|
||||
gold = nlp.make_doc(text)
|
||||
pred = nlp.make_doc(text)
|
||||
spans = []
|
||||
spans.append(gold.char_span(0, 4, label="PERSON"))
|
||||
spans.append(gold.char_span(0, 7, label="ORG"))
|
||||
spans.append(gold.char_span(8, 12, label="ORG"))
|
||||
gold.spans[key] = spans
|
||||
|
||||
def span_getter(doc, span_key):
|
||||
return doc.spans[span_key]
|
||||
|
||||
# Predict exactly the same, but overlapping spans will be discarded
|
||||
pred.spans[key] = spans
|
||||
eg = Example(pred, gold)
|
||||
scores = Scorer.score_spans([eg], attr=key, getter=span_getter)
|
||||
assert scores[f"{key}_p"] == 1.0
|
||||
assert scores[f"{key}_r"] < 1.0
|
||||
|
||||
# Allow overlapping, now both precision and recall should be 100%
|
||||
pred.spans[key] = spans
|
||||
eg = Example(pred, gold)
|
||||
scores = Scorer.score_spans([eg], attr=key, getter=span_getter, allow_overlap=True)
|
||||
assert scores[f"{key}_p"] == 1.0
|
||||
assert scores[f"{key}_r"] == 1.0
|
||||
|
||||
# Change the predicted labels
|
||||
new_spans = [Span(pred, span.start, span.end, label="WRONG") for span in spans]
|
||||
pred.spans[key] = new_spans
|
||||
eg = Example(pred, gold)
|
||||
scores = Scorer.score_spans([eg], attr=key, getter=span_getter, allow_overlap=True)
|
||||
assert scores[f"{key}_p"] == 0.0
|
||||
assert scores[f"{key}_r"] == 0.0
|
||||
assert f"{key}_per_type" in scores
|
||||
|
||||
# Discard labels from the evaluation
|
||||
scores = Scorer.score_spans([eg], attr=key, getter=span_getter, allow_overlap=True, labeled=False)
|
||||
assert scores[f"{key}_p"] == 1.0
|
||||
assert scores[f"{key}_r"] == 1.0
|
||||
assert f"{key}_per_type" not in scores
|
||||
|
||||
|
||||
def test_prf_score():
|
||||
cand = {"hi", "ho"}
|
||||
gold1 = {"yo", "hi"}
|
||||
|
|
|
@ -426,6 +426,29 @@ def test_aligned_spans_x2y(en_vocab, en_tokenizer):
|
|||
assert [(ent.start, ent.end) for ent in ents_x2y] == [(0, 2), (4, 6)]
|
||||
|
||||
|
||||
def test_aligned_spans_y2x_overlap(en_vocab, en_tokenizer):
|
||||
text = "I flew to San Francisco Valley"
|
||||
nlp = English()
|
||||
doc = nlp(text)
|
||||
# the reference doc has overlapping spans
|
||||
gold_doc = nlp.make_doc(text)
|
||||
spans = []
|
||||
prefix = "I flew to "
|
||||
spans.append(gold_doc.char_span(len(prefix), len(prefix + "San Francisco"), label="CITY"))
|
||||
spans.append(gold_doc.char_span(len(prefix), len(prefix + "San Francisco Valley"), label="VALLEY"))
|
||||
spans_key = "overlap_ents"
|
||||
gold_doc.spans[spans_key] = spans
|
||||
example = Example(doc, gold_doc)
|
||||
spans_gold = example.reference.spans[spans_key]
|
||||
assert [(ent.start, ent.end) for ent in spans_gold] == [(3, 5), (3, 6)]
|
||||
|
||||
# Ensure that 'get_aligned_spans_y2x' has the aligned entities correct
|
||||
spans_y2x_no_overlap = example.get_aligned_spans_y2x(spans_gold, allow_overlap=False)
|
||||
assert [(ent.start, ent.end) for ent in spans_y2x_no_overlap] == [(3, 5)]
|
||||
spans_y2x_overlap = example.get_aligned_spans_y2x(spans_gold, allow_overlap=True)
|
||||
assert [(ent.start, ent.end) for ent in spans_y2x_overlap] == [(3, 5), (3, 6)]
|
||||
|
||||
|
||||
def test_gold_ner_missing_tags(en_tokenizer):
|
||||
doc = en_tokenizer("I flew to Silicon Valley via London.")
|
||||
biluo_tags = [None, "O", "O", "B-LOC", "L-LOC", "O", "U-GPE", "O"]
|
||||
|
|
|
@ -213,17 +213,18 @@ cdef class Example:
|
|||
else:
|
||||
return [None] * len(self.x)
|
||||
|
||||
def get_aligned_spans_x2y(self, x_spans):
|
||||
return self._get_aligned_spans(self.y, x_spans, self.alignment.x2y)
|
||||
def get_aligned_spans_x2y(self, x_spans, allow_overlap=False):
|
||||
return self._get_aligned_spans(self.y, x_spans, self.alignment.x2y, allow_overlap)
|
||||
|
||||
def get_aligned_spans_y2x(self, y_spans):
|
||||
return self._get_aligned_spans(self.x, y_spans, self.alignment.y2x)
|
||||
def get_aligned_spans_y2x(self, y_spans, allow_overlap=False):
|
||||
return self._get_aligned_spans(self.x, y_spans, self.alignment.y2x, allow_overlap)
|
||||
|
||||
def _get_aligned_spans(self, doc, spans, align):
|
||||
def _get_aligned_spans(self, doc, spans, align, allow_overlap):
|
||||
seen = set()
|
||||
output = []
|
||||
for span in spans:
|
||||
indices = align[span.start : span.end].data.ravel()
|
||||
if not allow_overlap:
|
||||
indices = [idx for idx in indices if idx not in seen]
|
||||
if len(indices) >= 1:
|
||||
aligned_span = Span(doc, indices[0], indices[-1] + 1, label=span.label)
|
||||
|
@ -237,7 +238,7 @@ cdef class Example:
|
|||
def get_aligned_ner(self):
|
||||
if not self.y.has_annotation("ENT_IOB"):
|
||||
return [None] * len(self.x) # should this be 'missing' instead of 'None' ?
|
||||
x_ents = self.get_aligned_spans_y2x(self.y.ents)
|
||||
x_ents = self.get_aligned_spans_y2x(self.y.ents, allow_overlap=False)
|
||||
# Default to 'None' for missing values
|
||||
x_tags = offsets_to_biluo_tags(
|
||||
self.x,
|
||||
|
|
|
@ -57,9 +57,9 @@ see the [training format documentation](/api/data-formats#dict-input).
|
|||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ------------------------------------------------------------------------- |
|
||||
| -------------- | ----------------------------------------------------------------------------------- |
|
||||
| `predicted` | The document containing (partial) predictions. Cannot be `None`. ~~Doc~~ |
|
||||
| `example_dict` | `Dict[str, obj]` | The gold-standard annotations as a dictionary. Cannot be `None`. ~~Dict[str, Any]~~ |
|
||||
| `example_dict` | The gold-standard annotations as a dictionary. Cannot be `None`. ~~Dict[str, Any]~~ |
|
||||
| **RETURNS** | The newly constructed object. ~~Example~~ |
|
||||
|
||||
## Example.text {#text tag="property"}
|
||||
|
@ -212,8 +212,9 @@ align to the tokenization in [`Example.predicted`](/api/example#predicted).
|
|||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | ----------------------------------------------------------------------------- |
|
||||
| --------------- | -------------------------------------------------------------------------------------------- |
|
||||
| `y_spans` | `Span` objects aligned to the tokenization of `reference`. ~~Iterable[Span]~~ |
|
||||
| `allow_overlap` | Whether the resulting `Span` objects may overlap or not. Set to `False` by default. ~~bool~~ |
|
||||
| **RETURNS** | `Span` objects aligned to the tokenization of `predicted`. ~~List[Span]~~ |
|
||||
|
||||
## Example.get_aligned_spans_x2y {#get_aligned_spans_x2y tag="method"}
|
||||
|
@ -239,8 +240,9 @@ against the original gold-standard annotation.
|
|||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | ----------------------------------------------------------------------------- |
|
||||
| --------------- | -------------------------------------------------------------------------------------------- |
|
||||
| `x_spans` | `Span` objects aligned to the tokenization of `predicted`. ~~Iterable[Span]~~ |
|
||||
| `allow_overlap` | Whether the resulting `Span` objects may overlap or not. Set to `False` by default. ~~bool~~ |
|
||||
| **RETURNS** | `Span` objects aligned to the tokenization of `reference`. ~~List[Span]~~ |
|
||||
|
||||
## Example.to_dict {#to_dict tag="method"}
|
||||
|
|
|
@ -138,12 +138,14 @@ Returns PRF scores for labeled or unlabeled spans.
|
|||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `examples` | The `Example` objects holding both the predictions and the correct gold-standard annotations. ~~Iterable[Example]~~ |
|
||||
| `attr` | The attribute to score. ~~str~~ |
|
||||
| _keyword-only_ | |
|
||||
| `getter` | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the `Span` objects for an individual `Doc`. ~~Callable[[Doc, str], Iterable[Span]]~~ |
|
||||
| `has_annotation` | Defaults to `None`. If provided, `has_annotation(doc)` should return whether a `Doc` has annotation for this `attr`. Docs without annotation are skipped for scoring purposes. ~~Optional[Callable[[Doc], bool]]~~ |
|
||||
| `has_annotation` | Defaults to `None`. If provided, `has_annotation(doc)` should return whether a `Doc` has annotation for this `attr`. Docs without annotation are skipped for scoring purposes. ~~str~~ |
|
||||
| `labeled` | Defaults to `True`. If set to `False`, two spans will be considered equal if their start and end match, irrespective of their label. ~~bool~~ |
|
||||
| `allow_overlap` | Defaults to `False`. Whether or not to allow overlapping spans. If set to `False`, the alignment will automatically resolve conflicts. ~~bool~~ |
|
||||
| **RETURNS** | A dictionary containing the PRF scores under the keys `{attr}_p`, `{attr}_r`, `{attr}_f` and the per-type PRF scores under `{attr}_per_type`. ~~Dict[str, Union[float, Dict[str, float]]]~~ |
|
||||
|
||||
## Scorer.score_deps {#score_deps tag="staticmethod" new="3"}
|
||||
|
|
Loading…
Reference in New Issue
Block a user