From 204c2f116bd74c9a54d045742a33591fb36fb6d9 Mon Sep 17 00:00:00 2001
From: Sofie Van Landeghem
Date: Thu, 8 Apr 2021 12:19:17 +0200
Subject: [PATCH] Extend score_spans for overlapping & non-labeled spans (#7209)

* extend span scorer with consider_label and allow_overlap
* unit test for spans y2x overlap
* add score_spans unit test
* docs for new fields in scorer.score_spans
* rename to include_label
* spell out if-else for clarity
* rename to 'labeled'

Co-authored-by: Adriane Boyd
---
 spacy/scorer.py                       | 52 +++++++++++++++++----------
 spacy/tests/test_scorer.py            | 49 +++++++++++++++++++++++--
 spacy/tests/training/test_training.py | 23 ++++++++++++
 spacy/training/example.pyx            | 15 ++++----
 website/docs/api/example.md           | 32 +++++++++--------
 website/docs/api/scorer.md            | 18 +++++-----
 6 files changed, 139 insertions(+), 50 deletions(-)

diff --git a/spacy/scorer.py b/spacy/scorer.py
index 8061aa329..25df44f14 100644
--- a/spacy/scorer.py
+++ b/spacy/scorer.py
@@ -311,6 +311,8 @@ class Scorer:
         *,
         getter: Callable[[Doc, str], Iterable[Span]] = getattr,
         has_annotation: Optional[Callable[[Doc], bool]] = None,
+        labeled: bool = True,
+        allow_overlap: bool = False,
         **cfg,
     ) -> Dict[str, Any]:
         """Returns PRF scores for labeled spans.
@@ -323,6 +325,11 @@ class Scorer:
         has_annotation (Optional[Callable[[Doc], bool]]) should return whether a
             `Doc` has annotation for this `attr`. Docs without annotation are
             skipped for scoring purposes.
+        labeled (bool): Whether or not to include label information in
+            the evaluation. If set to `False`, two spans will be considered
+            equal if their start and end match, irrespective of their label.
+        allow_overlap (bool): Whether or not to allow overlapping spans.
+            If set to `False`, the alignment will automatically resolve conflicts.
         RETURNS (Dict[str, Any]): A dictionary containing the PRF scores under
             the keys attr_p/r/f and the per-type PRF scores under
             attr_per_type.
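Taken together, the two new flags combine as follows when scoring a custom spans key (a minimal illustration, not part of the patch; the `my_spans` key, the sentence, and the labels are invented, mirroring the unit test added below):

```python
from spacy.lang.en import English
from spacy.scorer import Scorer
from spacy.training import Example

nlp = English()
gold = nlp.make_doc("This is just a random sentence.")
pred = nlp.make_doc("This is just a random sentence.")
# Two of the three gold spans overlap ("This" vs. "This is")
spans = [
    gold.char_span(0, 4, label="PERSON"),
    gold.char_span(0, 7, label="ORG"),
    gold.char_span(8, 12, label="ORG"),
]
gold.spans["my_spans"] = spans
pred.spans["my_spans"] = spans
example = Example(pred, gold)
getter = lambda doc, key: doc.spans[key]

# Defaults (labeled=True, allow_overlap=False): alignment drops one of the
# overlapping predictions, so recall suffers even for identical annotations
strict = Scorer.score_spans([example], attr="my_spans", getter=getter)

# allow_overlap=True keeps overlapping spans; labeled=False compares spans
# on their boundaries only and omits the per-type breakdown
relaxed = Scorer.score_spans(
    [example], attr="my_spans", getter=getter, labeled=False, allow_overlap=True
)
print(strict["my_spans_r"], relaxed["my_spans_r"])  # e.g. 0.67 vs. 1.0
```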
@@ -351,33 +358,42 @@ class Scorer:
             gold_spans = set()
             pred_spans = set()
             for span in getter(gold_doc, attr):
-                gold_span = (span.label_, span.start, span.end - 1)
+                if labeled:
+                    gold_span = (span.label_, span.start, span.end - 1)
+                else:
+                    gold_span = (span.start, span.end - 1)
                 gold_spans.add(gold_span)
-                gold_per_type[span.label_].add((span.label_, span.start, span.end - 1))
+                gold_per_type[span.label_].add(gold_span)
             pred_per_type = {label: set() for label in labels}
-            for span in example.get_aligned_spans_x2y(getter(pred_doc, attr)):
-                pred_spans.add((span.label_, span.start, span.end - 1))
-                pred_per_type[span.label_].add((span.label_, span.start, span.end - 1))
+            for span in example.get_aligned_spans_x2y(getter(pred_doc, attr), allow_overlap):
+                if labeled:
+                    pred_span = (span.label_, span.start, span.end - 1)
+                else:
+                    pred_span = (span.start, span.end - 1)
+                pred_spans.add(pred_span)
+                pred_per_type[span.label_].add(pred_span)
             # Scores per label
-            for k, v in score_per_type.items():
-                if k in pred_per_type:
-                    v.score_set(pred_per_type[k], gold_per_type[k])
+            if labeled:
+                for k, v in score_per_type.items():
+                    if k in pred_per_type:
+                        v.score_set(pred_per_type[k], gold_per_type[k])
             # Score for all labels
             score.score_set(pred_spans, gold_spans)
-        if len(score) > 0:
-            return {
-                f"{attr}_p": score.precision,
-                f"{attr}_r": score.recall,
-                f"{attr}_f": score.fscore,
-                f"{attr}_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
-            }
-        else:
-            return {
+        # Assemble final result
+        final_scores = {
             f"{attr}_p": None,
             f"{attr}_r": None,
             f"{attr}_f": None,
-            f"{attr}_per_type": None,
         }
+        if labeled:
+            final_scores[f"{attr}_per_type"] = None
+        if len(score) > 0:
+            final_scores[f"{attr}_p"] = score.precision
+            final_scores[f"{attr}_r"] = score.recall
+            final_scores[f"{attr}_f"] = score.fscore
+            if labeled:
+                final_scores[f"{attr}_per_type"] = {k: v.to_dict() for k, v in score_per_type.items()}
+        return final_scores
 
     @staticmethod
     def score_cats(
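Internally, each span is reduced to a hashable tuple and the comparison is a plain set intersection via `PRFScore`. The unlabeled variant can be reproduced in isolation (a sketch using only `PRFScore`, which the test file below already imports; the offsets are invented):

```python
from spacy.scorer import PRFScore

# With labeled=False a span becomes (start, end - 1); with labeled=True it
# becomes (label, start, end - 1), so a wrong label breaks the match.
gold_spans = {(0, 0), (0, 1), (2, 2)}
pred_spans = {(0, 0), (2, 2), (3, 3)}  # one boundary miss, one spurious span
prf = PRFScore()
prf.score_set(pred_spans, gold_spans)
print(prf.precision, prf.recall, prf.fscore)  # 0.67, 0.67, 0.67
```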
+ key = "my_spans" + gold = nlp.make_doc(text) + pred = nlp.make_doc(text) + spans = [] + spans.append(gold.char_span(0, 4, label="PERSON")) + spans.append(gold.char_span(0, 7, label="ORG")) + spans.append(gold.char_span(8, 12, label="ORG")) + gold.spans[key] = spans + + def span_getter(doc, span_key): + return doc.spans[span_key] + + # Predict exactly the same, but overlapping spans will be discarded + pred.spans[key] = spans + eg = Example(pred, gold) + scores = Scorer.score_spans([eg], attr=key, getter=span_getter) + assert scores[f"{key}_p"] == 1.0 + assert scores[f"{key}_r"] < 1.0 + + # Allow overlapping, now both precision and recall should be 100% + pred.spans[key] = spans + eg = Example(pred, gold) + scores = Scorer.score_spans([eg], attr=key, getter=span_getter, allow_overlap=True) + assert scores[f"{key}_p"] == 1.0 + assert scores[f"{key}_r"] == 1.0 + + # Change the predicted labels + new_spans = [Span(pred, span.start, span.end, label="WRONG") for span in spans] + pred.spans[key] = new_spans + eg = Example(pred, gold) + scores = Scorer.score_spans([eg], attr=key, getter=span_getter, allow_overlap=True) + assert scores[f"{key}_p"] == 0.0 + assert scores[f"{key}_r"] == 0.0 + assert f"{key}_per_type" in scores + + # Discard labels from the evaluation + scores = Scorer.score_spans([eg], attr=key, getter=span_getter, allow_overlap=True, labeled=False) + assert scores[f"{key}_p"] == 1.0 + assert scores[f"{key}_r"] == 1.0 + assert f"{key}_per_type" not in scores + + def test_prf_score(): cand = {"hi", "ho"} gold1 = {"yo", "hi"} @@ -422,4 +467,4 @@ def test_prf_score(): assert (c.precision, c.recall, c.fscore) == approx((0.25, 0.5, 0.33333333)) a += b - assert (a.precision, a.recall, a.fscore) == approx((c.precision, c.recall, c.fscore)) + assert (a.precision, a.recall, a.fscore) == approx((c.precision, c.recall, c.fscore)) \ No newline at end of file diff --git a/spacy/tests/training/test_training.py b/spacy/tests/training/test_training.py index c7a85bf87..321c08c1e 100644 --- a/spacy/tests/training/test_training.py +++ b/spacy/tests/training/test_training.py @@ -426,6 +426,29 @@ def test_aligned_spans_x2y(en_vocab, en_tokenizer): assert [(ent.start, ent.end) for ent in ents_x2y] == [(0, 2), (4, 6)] +def test_aligned_spans_y2x_overlap(en_vocab, en_tokenizer): + text = "I flew to San Francisco Valley" + nlp = English() + doc = nlp(text) + # the reference doc has overlapping spans + gold_doc = nlp.make_doc(text) + spans = [] + prefix = "I flew to " + spans.append(gold_doc.char_span(len(prefix), len(prefix + "San Francisco"), label="CITY")) + spans.append(gold_doc.char_span(len(prefix), len(prefix + "San Francisco Valley"), label="VALLEY")) + spans_key = "overlap_ents" + gold_doc.spans[spans_key] = spans + example = Example(doc, gold_doc) + spans_gold = example.reference.spans[spans_key] + assert [(ent.start, ent.end) for ent in spans_gold] == [(3, 5), (3, 6)] + + # Ensure that 'get_aligned_spans_y2x' has the aligned entities correct + spans_y2x_no_overlap = example.get_aligned_spans_y2x(spans_gold, allow_overlap=False) + assert [(ent.start, ent.end) for ent in spans_y2x_no_overlap] == [(3, 5)] + spans_y2x_overlap = example.get_aligned_spans_y2x(spans_gold, allow_overlap=True) + assert [(ent.start, ent.end) for ent in spans_y2x_overlap] == [(3, 5), (3, 6)] + + def test_gold_ner_missing_tags(en_tokenizer): doc = en_tokenizer("I flew to Silicon Valley via London.") biluo_tags = [None, "O", "O", "B-LOC", "L-LOC", "O", "U-GPE", "O"] diff --git a/spacy/training/example.pyx 
diff --git a/spacy/training/example.pyx b/spacy/training/example.pyx
index 9cf825bf9..74af793bd 100644
--- a/spacy/training/example.pyx
+++ b/spacy/training/example.pyx
@@ -213,18 +213,19 @@ cdef class Example:
         else:
             return [None] * len(self.x)
 
-    def get_aligned_spans_x2y(self, x_spans):
-        return self._get_aligned_spans(self.y, x_spans, self.alignment.x2y)
+    def get_aligned_spans_x2y(self, x_spans, allow_overlap=False):
+        return self._get_aligned_spans(self.y, x_spans, self.alignment.x2y, allow_overlap)
 
-    def get_aligned_spans_y2x(self, y_spans):
-        return self._get_aligned_spans(self.x, y_spans, self.alignment.y2x)
+    def get_aligned_spans_y2x(self, y_spans, allow_overlap=False):
+        return self._get_aligned_spans(self.x, y_spans, self.alignment.y2x, allow_overlap)
 
-    def _get_aligned_spans(self, doc, spans, align):
+    def _get_aligned_spans(self, doc, spans, align, allow_overlap):
         seen = set()
         output = []
         for span in spans:
             indices = align[span.start : span.end].data.ravel()
-            indices = [idx for idx in indices if idx not in seen]
+            if not allow_overlap:
+                indices = [idx for idx in indices if idx not in seen]
             if len(indices) >= 1:
                 aligned_span = Span(doc, indices[0], indices[-1] + 1, label=span.label)
                 target_text = span.text.lower().strip().replace(" ", "")
@@ -237,7 +238,7 @@ cdef class Example:
     def get_aligned_ner(self):
         if not self.y.has_annotation("ENT_IOB"):
             return [None] * len(self.x)  # should this be 'missing' instead of 'None' ?
-        x_ents = self.get_aligned_spans_y2x(self.y.ents)
+        x_ents = self.get_aligned_spans_y2x(self.y.ents, allow_overlap=False)
         # Default to 'None' for missing values
         x_tags = offsets_to_biluo_tags(
             self.x,
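With this change applied, the flag is exposed directly on the two public helpers documented below (a usage sketch mirroring the new unit test above; the `cities` spans key is invented):

```python
from spacy.lang.en import English
from spacy.training import Example

nlp = English()
text = "I flew to San Francisco Valley"
pred_doc = nlp.make_doc(text)
gold_doc = nlp.make_doc(text)
prefix = "I flew to "
gold_doc.spans["cities"] = [
    gold_doc.char_span(len(prefix), len(prefix + "San Francisco"), label="CITY"),
    gold_doc.char_span(len(prefix), len(prefix + "San Francisco Valley"), label="VALLEY"),
]
example = Example(pred_doc, gold_doc)
gold_spans = example.reference.spans["cities"]

# Default resolves the overlap, keeping only the first span
print([(s.start, s.end) for s in example.get_aligned_spans_y2x(gold_spans)])
# [(3, 5)]
print([(s.start, s.end) for s in example.get_aligned_spans_y2x(gold_spans, allow_overlap=True)])
# [(3, 5), (3, 6)]
```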
diff --git a/website/docs/api/example.md b/website/docs/api/example.md
index 2811f4d91..ca9d3c056 100644
--- a/website/docs/api/example.md
+++ b/website/docs/api/example.md
@@ -33,8 +33,8 @@ both documents.
 
 | Name           | Description                                                                                                                |
 | -------------- | -------------------------------------------------------------------------------------------------------------------------- |
-| `predicted`    | The document containing (partial) predictions. Cannot be `None`. ~~Doc~~                                                 |
-| `reference`    | The document containing gold-standard annotations. Cannot be `None`. ~~Doc~~                                             |
+| `predicted`    | The document containing (partial) predictions. Cannot be `None`. ~~Doc~~                                                  |
+| `reference`    | The document containing gold-standard annotations. Cannot be `None`. ~~Doc~~                                              |
 | _keyword-only_ |                                                                                                                            |
 | `alignment`    | An object holding the alignment between the tokens of the `predicted` and `reference` documents. ~~Optional[Alignment]~~  |
 
@@ -56,11 +56,11 @@ see the [training format documentation](/api/data-formats#dict-input).
 > example = Example.from_dict(predicted, {"words": token_ref, "tags": tags_ref})
 > ```
 
-| Name           | Description                                                                                            |
-| -------------- | ------------------------------------------------------------------------------------------------------ |
-| `predicted`    | The document containing (partial) predictions. Cannot be `None`. ~~Doc~~                              |
-| `example_dict` | `Dict[str, obj]` | The gold-standard annotations as a dictionary. Cannot be `None`. ~~Dict[str, Any]~~ |
-| **RETURNS**    | The newly constructed object. ~~Example~~                                                             |
+| Name           | Description                                                                           |
+| -------------- | ------------------------------------------------------------------------------------- |
+| `predicted`    | The document containing (partial) predictions. Cannot be `None`. ~~Doc~~             |
+| `example_dict` | The gold-standard annotations as a dictionary. Cannot be `None`. ~~Dict[str, Any]~~  |
+| **RETURNS**    | The newly constructed object. ~~Example~~                                            |
 
 ## Example.text {#text tag="property"}
 
@@ -211,10 +211,11 @@ align to the tokenization in [`Example.predicted`](/api/example#predicted).
 > assert [(ent.start, ent.end) for ent in ents_y2x] == [(0, 1)]
 > ```
 
-| Name        | Description                                                                    |
-| ----------- | ------------------------------------------------------------------------------ |
-| `y_spans`   | `Span` objects aligned to the tokenization of `reference`. ~~Iterable[Span]~~ |
-| **RETURNS** | `Span` objects aligned to the tokenization of `predicted`. ~~List[Span]~~     |
+| Name            | Description                                                                                    |
+| --------------- | ---------------------------------------------------------------------------------------------- |
+| `y_spans`       | `Span` objects aligned to the tokenization of `reference`. ~~Iterable[Span]~~                 |
+| `allow_overlap` | Whether the resulting `Span` objects may overlap or not. Set to `False` by default. ~~bool~~  |
+| **RETURNS**     | `Span` objects aligned to the tokenization of `predicted`. ~~List[Span]~~                     |
 
 ## Example.get_aligned_spans_x2y {#get_aligned_spans_x2y tag="method"}
 
@@ -238,10 +239,11 @@ against the original gold-standard annotation.
 > assert [(ent.start, ent.end) for ent in ents_x2y] == [(0, 2)]
 > ```
 
-| Name        | Description                                                                    |
-| ----------- | ------------------------------------------------------------------------------ |
-| `x_spans`   | `Span` objects aligned to the tokenization of `predicted`. ~~Iterable[Span]~~ |
-| **RETURNS** | `Span` objects aligned to the tokenization of `reference`. ~~List[Span]~~     |
+| Name            | Description                                                                                    |
+| --------------- | ---------------------------------------------------------------------------------------------- |
+| `x_spans`       | `Span` objects aligned to the tokenization of `predicted`. ~~Iterable[Span]~~                 |
+| `allow_overlap` | Whether the resulting `Span` objects may overlap or not. Set to `False` by default. ~~bool~~  |
+| **RETURNS**     | `Span` objects aligned to the tokenization of `reference`. ~~List[Span]~~                     |
 
 ## Example.to_dict {#to_dict tag="method"}
diff --git a/website/docs/api/scorer.md b/website/docs/api/scorer.md
index cf1a1ca1f..7398bae81 100644
--- a/website/docs/api/scorer.md
+++ b/website/docs/api/scorer.md
@@ -137,14 +137,16 @@ Returns PRF scores for labeled or unlabeled spans.
 > print(scores["ents_f"])
 > ```
 
-| Name             | Description                                                                                                                                                                                                          |
-| ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
-| `examples`       | The `Example` objects holding both the predictions and the correct gold-standard annotations. ~~Iterable[Example]~~                                                                                                |
-| `attr`           | The attribute to score. ~~str~~                                                                                                                                                                                     |
-| _keyword-only_   |                                                                                                                                                                                                                      |
-| `getter`         | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the `Span` objects for an individual `Doc`. ~~Callable[[Doc, str], Iterable[Span]]~~                                                         |
-| `has_annotation` | Defaults to `None`. If provided, `has_annotation(doc)` should return whether a `Doc` has annotation for this `attr`. Docs without annotation are skipped for scoring purposes. ~~Optional[Callable[[Doc], bool]]~~ |
-| **RETURNS**      | A dictionary containing the PRF scores under the keys `{attr}_p`, `{attr}_r`, `{attr}_f` and the per-type PRF scores under `{attr}_per_type`. ~~Dict[str, Union[float, Dict[str, float]]]~~                        |
+| Name             | Description                                                                                                                                                                                    |
+| ---------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `examples`       | The `Example` objects holding both the predictions and the correct gold-standard annotations. ~~Iterable[Example]~~                                                                          |
+| `attr`           | The attribute to score. ~~str~~                                                                                                                                                               |
+| _keyword-only_   |                                                                                                                                                                                                |
+| `getter`         | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the `Span` objects for an individual `Doc`. ~~Callable[[Doc, str], Iterable[Span]]~~                                   |
+| `has_annotation` | Defaults to `None`. If provided, `has_annotation(doc)` should return whether a `Doc` has annotation for this `attr`. Docs without annotation are skipped for scoring purposes. ~~Optional[Callable[[Doc], bool]]~~ |
+| `labeled`        | Defaults to `True`. If set to `False`, two spans will be considered equal if their start and end match, irrespective of their label. ~~bool~~                                                |
+| `allow_overlap`  | Defaults to `False`. Whether or not to allow overlapping spans. If set to `False`, the alignment will automatically resolve conflicts. ~~bool~~                                              |
+| **RETURNS**      | A dictionary containing the PRF scores under the keys `{attr}_p`, `{attr}_r`, `{attr}_f` and the per-type PRF scores under `{attr}_per_type`. ~~Dict[str, Union[float, Dict[str, float]]]~~  |
 
 ## Scorer.score_deps {#score_deps tag="staticmethod" new="3"}
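For reference, the dictionary returned by the patched `score_spans` then has the following shape (a sketch of the structure implied by the `final_scores` assembly above; `my_spans` is an invented key and all values are illustrative, with per-type entries coming from `PRFScore.to_dict`):

```python
# labeled=True (default): per-type breakdown under "{attr}_per_type"
{
    "my_spans_p": 0.75,
    "my_spans_r": 0.6,
    "my_spans_f": 0.667,
    "my_spans_per_type": {
        "PERSON": {"p": 1.0, "r": 0.5, "f": 0.667},
        "ORG": {"p": 0.5, "r": 1.0, "f": 0.667},
    },
}

# labeled=False: the "{attr}_per_type" key is omitted entirely
{
    "my_spans_p": 1.0,
    "my_spans_r": 1.0,
    "my_spans_f": 1.0,
}
```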