Extend score_spans for overlapping & non-labeled spans (#7209)

* extend span scorer with consider_label and allow_overlap

* unit test for spans y2x overlap

* add score_spans unit test

* docs for new fields in scorer.score_spans

* rename to include_label

* spell out if-else for clarity

* rename to 'labeled'

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
Sofie Van Landeghem 2021-04-08 12:19:17 +02:00 committed by GitHub
parent c362006cb9
commit 204c2f116b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 139 additions and 50 deletions

View File

@ -311,6 +311,8 @@ class Scorer:
*,
getter: Callable[[Doc, str], Iterable[Span]] = getattr,
has_annotation: Optional[Callable[[Doc], bool]] = None,
labeled: bool = True,
allow_overlap: bool = False,
**cfg,
) -> Dict[str, Any]:
"""Returns PRF scores for labeled spans.
@ -323,6 +325,11 @@ class Scorer:
has_annotation (Optional[Callable[[Doc], bool]]) should return whether a `Doc`
has annotation for this `attr`. Docs without annotation are skipped for
scoring purposes.
labeled (bool): Whether or not to include label information in
the evaluation. If set to 'False', two spans will be considered
equal if their start and end match, irrespective of their label.
allow_overlap (bool): Whether or not to allow overlapping spans.
If set to 'False', the alignment will automatically resolve conflicts.
RETURNS (Dict[str, Any]): A dictionary containing the PRF scores under
the keys attr_p/r/f and the per-type PRF scores under attr_per_type.
@ -351,33 +358,42 @@ class Scorer:
gold_spans = set()
pred_spans = set()
for span in getter(gold_doc, attr):
if labeled:
gold_span = (span.label_, span.start, span.end - 1)
else:
gold_span = (span.start, span.end - 1)
gold_spans.add(gold_span)
gold_per_type[span.label_].add((span.label_, span.start, span.end - 1))
gold_per_type[span.label_].add(gold_span)
pred_per_type = {label: set() for label in labels}
for span in example.get_aligned_spans_x2y(getter(pred_doc, attr)):
pred_spans.add((span.label_, span.start, span.end - 1))
pred_per_type[span.label_].add((span.label_, span.start, span.end - 1))
for span in example.get_aligned_spans_x2y(getter(pred_doc, attr), allow_overlap):
if labeled:
pred_span = (span.label_, span.start, span.end - 1)
else:
pred_span = (span.start, span.end - 1)
pred_spans.add(pred_span)
pred_per_type[span.label_].add(pred_span)
# Scores per label
if labeled:
for k, v in score_per_type.items():
if k in pred_per_type:
v.score_set(pred_per_type[k], gold_per_type[k])
# Score for all labels
score.score_set(pred_spans, gold_spans)
if len(score) > 0:
return {
f"{attr}_p": score.precision,
f"{attr}_r": score.recall,
f"{attr}_f": score.fscore,
f"{attr}_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
}
else:
return {
# Assemble final result
final_scores = {
f"{attr}_p": None,
f"{attr}_r": None,
f"{attr}_f": None,
f"{attr}_per_type": None,
}
if labeled:
final_scores[f"{attr}_per_type"] = None
if len(score) > 0:
final_scores[f"{attr}_p"] = score.precision
final_scores[f"{attr}_r"] = score.recall
final_scores[f"{attr}_f"] = score.fscore
if labeled:
final_scores[f"{attr}_per_type"] = {k: v.to_dict() for k, v in score_per_type.items()}
return final_scores
@staticmethod
def score_cats(

View File

@ -6,7 +6,7 @@ from spacy.training.iob_utils import offsets_to_biluo_tags
from spacy.scorer import Scorer, ROCAUCScore, PRFScore
from spacy.scorer import _roc_auc_score, _roc_curve
from spacy.lang.en import English
from spacy.tokens import Doc
from spacy.tokens import Doc, Span
test_las_apple = [
@ -405,6 +405,51 @@ def test_roc_auc_score():
_ = score.score # noqa: F841
def test_score_spans():
nlp = English()
text = "This is just a random sentence."
key = "my_spans"
gold = nlp.make_doc(text)
pred = nlp.make_doc(text)
spans = []
spans.append(gold.char_span(0, 4, label="PERSON"))
spans.append(gold.char_span(0, 7, label="ORG"))
spans.append(gold.char_span(8, 12, label="ORG"))
gold.spans[key] = spans
def span_getter(doc, span_key):
return doc.spans[span_key]
# Predict exactly the same, but overlapping spans will be discarded
pred.spans[key] = spans
eg = Example(pred, gold)
scores = Scorer.score_spans([eg], attr=key, getter=span_getter)
assert scores[f"{key}_p"] == 1.0
assert scores[f"{key}_r"] < 1.0
# Allow overlapping, now both precision and recall should be 100%
pred.spans[key] = spans
eg = Example(pred, gold)
scores = Scorer.score_spans([eg], attr=key, getter=span_getter, allow_overlap=True)
assert scores[f"{key}_p"] == 1.0
assert scores[f"{key}_r"] == 1.0
# Change the predicted labels
new_spans = [Span(pred, span.start, span.end, label="WRONG") for span in spans]
pred.spans[key] = new_spans
eg = Example(pred, gold)
scores = Scorer.score_spans([eg], attr=key, getter=span_getter, allow_overlap=True)
assert scores[f"{key}_p"] == 0.0
assert scores[f"{key}_r"] == 0.0
assert f"{key}_per_type" in scores
# Discard labels from the evaluation
scores = Scorer.score_spans([eg], attr=key, getter=span_getter, allow_overlap=True, labeled=False)
assert scores[f"{key}_p"] == 1.0
assert scores[f"{key}_r"] == 1.0
assert f"{key}_per_type" not in scores
def test_prf_score():
cand = {"hi", "ho"}
gold1 = {"yo", "hi"}

View File

@ -426,6 +426,29 @@ def test_aligned_spans_x2y(en_vocab, en_tokenizer):
assert [(ent.start, ent.end) for ent in ents_x2y] == [(0, 2), (4, 6)]
def test_aligned_spans_y2x_overlap(en_vocab, en_tokenizer):
text = "I flew to San Francisco Valley"
nlp = English()
doc = nlp(text)
# the reference doc has overlapping spans
gold_doc = nlp.make_doc(text)
spans = []
prefix = "I flew to "
spans.append(gold_doc.char_span(len(prefix), len(prefix + "San Francisco"), label="CITY"))
spans.append(gold_doc.char_span(len(prefix), len(prefix + "San Francisco Valley"), label="VALLEY"))
spans_key = "overlap_ents"
gold_doc.spans[spans_key] = spans
example = Example(doc, gold_doc)
spans_gold = example.reference.spans[spans_key]
assert [(ent.start, ent.end) for ent in spans_gold] == [(3, 5), (3, 6)]
# Ensure that 'get_aligned_spans_y2x' has the aligned entities correct
spans_y2x_no_overlap = example.get_aligned_spans_y2x(spans_gold, allow_overlap=False)
assert [(ent.start, ent.end) for ent in spans_y2x_no_overlap] == [(3, 5)]
spans_y2x_overlap = example.get_aligned_spans_y2x(spans_gold, allow_overlap=True)
assert [(ent.start, ent.end) for ent in spans_y2x_overlap] == [(3, 5), (3, 6)]
def test_gold_ner_missing_tags(en_tokenizer):
doc = en_tokenizer("I flew to Silicon Valley via London.")
biluo_tags = [None, "O", "O", "B-LOC", "L-LOC", "O", "U-GPE", "O"]

View File

@ -213,17 +213,18 @@ cdef class Example:
else:
return [None] * len(self.x)
def get_aligned_spans_x2y(self, x_spans):
return self._get_aligned_spans(self.y, x_spans, self.alignment.x2y)
def get_aligned_spans_x2y(self, x_spans, allow_overlap=False):
return self._get_aligned_spans(self.y, x_spans, self.alignment.x2y, allow_overlap)
def get_aligned_spans_y2x(self, y_spans):
return self._get_aligned_spans(self.x, y_spans, self.alignment.y2x)
def get_aligned_spans_y2x(self, y_spans, allow_overlap=False):
return self._get_aligned_spans(self.x, y_spans, self.alignment.y2x, allow_overlap)
def _get_aligned_spans(self, doc, spans, align):
def _get_aligned_spans(self, doc, spans, align, allow_overlap):
seen = set()
output = []
for span in spans:
indices = align[span.start : span.end].data.ravel()
if not allow_overlap:
indices = [idx for idx in indices if idx not in seen]
if len(indices) >= 1:
aligned_span = Span(doc, indices[0], indices[-1] + 1, label=span.label)
@ -237,7 +238,7 @@ cdef class Example:
def get_aligned_ner(self):
if not self.y.has_annotation("ENT_IOB"):
return [None] * len(self.x) # should this be 'missing' instead of 'None' ?
x_ents = self.get_aligned_spans_y2x(self.y.ents)
x_ents = self.get_aligned_spans_y2x(self.y.ents, allow_overlap=False)
# Default to 'None' for missing values
x_tags = offsets_to_biluo_tags(
self.x,

View File

@ -57,9 +57,9 @@ see the [training format documentation](/api/data-formats#dict-input).
> ```
| Name | Description |
| -------------- | ------------------------------------------------------------------------- |
| -------------- | ----------------------------------------------------------------------------------- |
| `predicted` | The document containing (partial) predictions. Cannot be `None`. ~~Doc~~ |
| `example_dict` | `Dict[str, obj]` | The gold-standard annotations as a dictionary. Cannot be `None`. ~~Dict[str, Any]~~ |
| `example_dict` | The gold-standard annotations as a dictionary. Cannot be `None`. ~~Dict[str, Any]~~ |
| **RETURNS** | The newly constructed object. ~~Example~~ |
## Example.text {#text tag="property"}
@ -212,8 +212,9 @@ align to the tokenization in [`Example.predicted`](/api/example#predicted).
> ```
| Name | Description |
| ----------- | ----------------------------------------------------------------------------- |
| --------------- | -------------------------------------------------------------------------------------------- |
| `y_spans` | `Span` objects aligned to the tokenization of `reference`. ~~Iterable[Span]~~ |
| `allow_overlap` | Whether the resulting `Span` objects may overlap or not. Set to `False` by default. ~~bool~~ |
| **RETURNS** | `Span` objects aligned to the tokenization of `predicted`. ~~List[Span]~~ |
## Example.get_aligned_spans_x2y {#get_aligned_spans_x2y tag="method"}
@ -239,8 +240,9 @@ against the original gold-standard annotation.
> ```
| Name | Description |
| ----------- | ----------------------------------------------------------------------------- |
| --------------- | -------------------------------------------------------------------------------------------- |
| `x_spans` | `Span` objects aligned to the tokenization of `predicted`. ~~Iterable[Span]~~ |
| `allow_overlap` | Whether the resulting `Span` objects may overlap or not. Set to `False` by default. ~~bool~~ |
| **RETURNS** | `Span` objects aligned to the tokenization of `reference`. ~~List[Span]~~ |
## Example.to_dict {#to_dict tag="method"}

View File

@ -138,12 +138,14 @@ Returns PRF scores for labeled or unlabeled spans.
> ```
| Name | Description |
| ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `examples` | The `Example` objects holding both the predictions and the correct gold-standard annotations. ~~Iterable[Example]~~ |
| `attr` | The attribute to score. ~~str~~ |
| _keyword-only_ | |
| `getter` | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the `Span` objects for an individual `Doc`. ~~Callable[[Doc, str], Iterable[Span]]~~ |
| `has_annotation` | Defaults to `None`. If provided, `has_annotation(doc)` should return whether a `Doc` has annotation for this `attr`. Docs without annotation are skipped for scoring purposes. ~~Optional[Callable[[Doc], bool]]~~ |
| `has_annotation` | Defaults to `None`. If provided, `has_annotation(doc)` should return whether a `Doc` has annotation for this `attr`. Docs without annotation are skipped for scoring purposes. ~~str~~ |
| `labeled` | Defaults to `True`. If set to `False`, two spans will be considered equal if their start and end match, irrespective of their label. ~~bool~~ |
| `allow_overlap` | Defaults to `False`. Whether or not to allow overlapping spans. If set to `False`, the alignment will automatically resolve conflicts. ~~bool~~ |
| **RETURNS** | A dictionary containing the PRF scores under the keys `{attr}_p`, `{attr}_r`, `{attr}_f` and the per-type PRF scores under `{attr}_per_type`. ~~Dict[str, Union[float, Dict[str, float]]]~~ |
## Scorer.score_deps {#score_deps tag="staticmethod" new="3"}