mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Extend score_spans for overlapping & non-labeled spans (#7209)
* extend span scorer with consider_label and allow_overlap * unit test for spans y2x overlap * add score_spans unit test * docs for new fields in scorer.score_spans * rename to include_label * spell out if-else for clarity * rename to 'labeled' Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
c362006cb9
commit
204c2f116b
|
@ -311,6 +311,8 @@ class Scorer:
|
||||||
*,
|
*,
|
||||||
getter: Callable[[Doc, str], Iterable[Span]] = getattr,
|
getter: Callable[[Doc, str], Iterable[Span]] = getattr,
|
||||||
has_annotation: Optional[Callable[[Doc], bool]] = None,
|
has_annotation: Optional[Callable[[Doc], bool]] = None,
|
||||||
|
labeled: bool = True,
|
||||||
|
allow_overlap: bool = False,
|
||||||
**cfg,
|
**cfg,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Returns PRF scores for labeled spans.
|
"""Returns PRF scores for labeled spans.
|
||||||
|
@ -323,6 +325,11 @@ class Scorer:
|
||||||
has_annotation (Optional[Callable[[Doc], bool]]) should return whether a `Doc`
|
has_annotation (Optional[Callable[[Doc], bool]]) should return whether a `Doc`
|
||||||
has annotation for this `attr`. Docs without annotation are skipped for
|
has annotation for this `attr`. Docs without annotation are skipped for
|
||||||
scoring purposes.
|
scoring purposes.
|
||||||
|
labeled (bool): Whether or not to include label information in
|
||||||
|
the evaluation. If set to 'False', two spans will be considered
|
||||||
|
equal if their start and end match, irrespective of their label.
|
||||||
|
allow_overlap (bool): Whether or not to allow overlapping spans.
|
||||||
|
If set to 'False', the alignment will automatically resolve conflicts.
|
||||||
RETURNS (Dict[str, Any]): A dictionary containing the PRF scores under
|
RETURNS (Dict[str, Any]): A dictionary containing the PRF scores under
|
||||||
the keys attr_p/r/f and the per-type PRF scores under attr_per_type.
|
the keys attr_p/r/f and the per-type PRF scores under attr_per_type.
|
||||||
|
|
||||||
|
@ -351,33 +358,42 @@ class Scorer:
|
||||||
gold_spans = set()
|
gold_spans = set()
|
||||||
pred_spans = set()
|
pred_spans = set()
|
||||||
for span in getter(gold_doc, attr):
|
for span in getter(gold_doc, attr):
|
||||||
gold_span = (span.label_, span.start, span.end - 1)
|
if labeled:
|
||||||
|
gold_span = (span.label_, span.start, span.end - 1)
|
||||||
|
else:
|
||||||
|
gold_span = (span.start, span.end - 1)
|
||||||
gold_spans.add(gold_span)
|
gold_spans.add(gold_span)
|
||||||
gold_per_type[span.label_].add((span.label_, span.start, span.end - 1))
|
gold_per_type[span.label_].add(gold_span)
|
||||||
pred_per_type = {label: set() for label in labels}
|
pred_per_type = {label: set() for label in labels}
|
||||||
for span in example.get_aligned_spans_x2y(getter(pred_doc, attr)):
|
for span in example.get_aligned_spans_x2y(getter(pred_doc, attr), allow_overlap):
|
||||||
pred_spans.add((span.label_, span.start, span.end - 1))
|
if labeled:
|
||||||
pred_per_type[span.label_].add((span.label_, span.start, span.end - 1))
|
pred_span = (span.label_, span.start, span.end - 1)
|
||||||
|
else:
|
||||||
|
pred_span = (span.start, span.end - 1)
|
||||||
|
pred_spans.add(pred_span)
|
||||||
|
pred_per_type[span.label_].add(pred_span)
|
||||||
# Scores per label
|
# Scores per label
|
||||||
for k, v in score_per_type.items():
|
if labeled:
|
||||||
if k in pred_per_type:
|
for k, v in score_per_type.items():
|
||||||
v.score_set(pred_per_type[k], gold_per_type[k])
|
if k in pred_per_type:
|
||||||
|
v.score_set(pred_per_type[k], gold_per_type[k])
|
||||||
# Score for all labels
|
# Score for all labels
|
||||||
score.score_set(pred_spans, gold_spans)
|
score.score_set(pred_spans, gold_spans)
|
||||||
if len(score) > 0:
|
# Assemble final result
|
||||||
return {
|
final_scores = {
|
||||||
f"{attr}_p": score.precision,
|
|
||||||
f"{attr}_r": score.recall,
|
|
||||||
f"{attr}_f": score.fscore,
|
|
||||||
f"{attr}_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
f"{attr}_p": None,
|
f"{attr}_p": None,
|
||||||
f"{attr}_r": None,
|
f"{attr}_r": None,
|
||||||
f"{attr}_f": None,
|
f"{attr}_f": None,
|
||||||
f"{attr}_per_type": None,
|
|
||||||
}
|
}
|
||||||
|
if labeled:
|
||||||
|
final_scores[f"{attr}_per_type"] = None
|
||||||
|
if len(score) > 0:
|
||||||
|
final_scores[f"{attr}_p"] = score.precision
|
||||||
|
final_scores[f"{attr}_r"] = score.recall
|
||||||
|
final_scores[f"{attr}_f"] = score.fscore
|
||||||
|
if labeled:
|
||||||
|
final_scores[f"{attr}_per_type"] = {k: v.to_dict() for k, v in score_per_type.items()}
|
||||||
|
return final_scores
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def score_cats(
|
def score_cats(
|
||||||
|
|
|
@ -6,7 +6,7 @@ from spacy.training.iob_utils import offsets_to_biluo_tags
|
||||||
from spacy.scorer import Scorer, ROCAUCScore, PRFScore
|
from spacy.scorer import Scorer, ROCAUCScore, PRFScore
|
||||||
from spacy.scorer import _roc_auc_score, _roc_curve
|
from spacy.scorer import _roc_auc_score, _roc_curve
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc, Span
|
||||||
|
|
||||||
|
|
||||||
test_las_apple = [
|
test_las_apple = [
|
||||||
|
@ -405,6 +405,51 @@ def test_roc_auc_score():
|
||||||
_ = score.score # noqa: F841
|
_ = score.score # noqa: F841
|
||||||
|
|
||||||
|
|
||||||
|
def test_score_spans():
|
||||||
|
nlp = English()
|
||||||
|
text = "This is just a random sentence."
|
||||||
|
key = "my_spans"
|
||||||
|
gold = nlp.make_doc(text)
|
||||||
|
pred = nlp.make_doc(text)
|
||||||
|
spans = []
|
||||||
|
spans.append(gold.char_span(0, 4, label="PERSON"))
|
||||||
|
spans.append(gold.char_span(0, 7, label="ORG"))
|
||||||
|
spans.append(gold.char_span(8, 12, label="ORG"))
|
||||||
|
gold.spans[key] = spans
|
||||||
|
|
||||||
|
def span_getter(doc, span_key):
|
||||||
|
return doc.spans[span_key]
|
||||||
|
|
||||||
|
# Predict exactly the same, but overlapping spans will be discarded
|
||||||
|
pred.spans[key] = spans
|
||||||
|
eg = Example(pred, gold)
|
||||||
|
scores = Scorer.score_spans([eg], attr=key, getter=span_getter)
|
||||||
|
assert scores[f"{key}_p"] == 1.0
|
||||||
|
assert scores[f"{key}_r"] < 1.0
|
||||||
|
|
||||||
|
# Allow overlapping, now both precision and recall should be 100%
|
||||||
|
pred.spans[key] = spans
|
||||||
|
eg = Example(pred, gold)
|
||||||
|
scores = Scorer.score_spans([eg], attr=key, getter=span_getter, allow_overlap=True)
|
||||||
|
assert scores[f"{key}_p"] == 1.0
|
||||||
|
assert scores[f"{key}_r"] == 1.0
|
||||||
|
|
||||||
|
# Change the predicted labels
|
||||||
|
new_spans = [Span(pred, span.start, span.end, label="WRONG") for span in spans]
|
||||||
|
pred.spans[key] = new_spans
|
||||||
|
eg = Example(pred, gold)
|
||||||
|
scores = Scorer.score_spans([eg], attr=key, getter=span_getter, allow_overlap=True)
|
||||||
|
assert scores[f"{key}_p"] == 0.0
|
||||||
|
assert scores[f"{key}_r"] == 0.0
|
||||||
|
assert f"{key}_per_type" in scores
|
||||||
|
|
||||||
|
# Discard labels from the evaluation
|
||||||
|
scores = Scorer.score_spans([eg], attr=key, getter=span_getter, allow_overlap=True, labeled=False)
|
||||||
|
assert scores[f"{key}_p"] == 1.0
|
||||||
|
assert scores[f"{key}_r"] == 1.0
|
||||||
|
assert f"{key}_per_type" not in scores
|
||||||
|
|
||||||
|
|
||||||
def test_prf_score():
|
def test_prf_score():
|
||||||
cand = {"hi", "ho"}
|
cand = {"hi", "ho"}
|
||||||
gold1 = {"yo", "hi"}
|
gold1 = {"yo", "hi"}
|
||||||
|
|
|
@ -426,6 +426,29 @@ def test_aligned_spans_x2y(en_vocab, en_tokenizer):
|
||||||
assert [(ent.start, ent.end) for ent in ents_x2y] == [(0, 2), (4, 6)]
|
assert [(ent.start, ent.end) for ent in ents_x2y] == [(0, 2), (4, 6)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_aligned_spans_y2x_overlap(en_vocab, en_tokenizer):
|
||||||
|
text = "I flew to San Francisco Valley"
|
||||||
|
nlp = English()
|
||||||
|
doc = nlp(text)
|
||||||
|
# the reference doc has overlapping spans
|
||||||
|
gold_doc = nlp.make_doc(text)
|
||||||
|
spans = []
|
||||||
|
prefix = "I flew to "
|
||||||
|
spans.append(gold_doc.char_span(len(prefix), len(prefix + "San Francisco"), label="CITY"))
|
||||||
|
spans.append(gold_doc.char_span(len(prefix), len(prefix + "San Francisco Valley"), label="VALLEY"))
|
||||||
|
spans_key = "overlap_ents"
|
||||||
|
gold_doc.spans[spans_key] = spans
|
||||||
|
example = Example(doc, gold_doc)
|
||||||
|
spans_gold = example.reference.spans[spans_key]
|
||||||
|
assert [(ent.start, ent.end) for ent in spans_gold] == [(3, 5), (3, 6)]
|
||||||
|
|
||||||
|
# Ensure that 'get_aligned_spans_y2x' has the aligned entities correct
|
||||||
|
spans_y2x_no_overlap = example.get_aligned_spans_y2x(spans_gold, allow_overlap=False)
|
||||||
|
assert [(ent.start, ent.end) for ent in spans_y2x_no_overlap] == [(3, 5)]
|
||||||
|
spans_y2x_overlap = example.get_aligned_spans_y2x(spans_gold, allow_overlap=True)
|
||||||
|
assert [(ent.start, ent.end) for ent in spans_y2x_overlap] == [(3, 5), (3, 6)]
|
||||||
|
|
||||||
|
|
||||||
def test_gold_ner_missing_tags(en_tokenizer):
|
def test_gold_ner_missing_tags(en_tokenizer):
|
||||||
doc = en_tokenizer("I flew to Silicon Valley via London.")
|
doc = en_tokenizer("I flew to Silicon Valley via London.")
|
||||||
biluo_tags = [None, "O", "O", "B-LOC", "L-LOC", "O", "U-GPE", "O"]
|
biluo_tags = [None, "O", "O", "B-LOC", "L-LOC", "O", "U-GPE", "O"]
|
||||||
|
|
|
@ -213,18 +213,19 @@ cdef class Example:
|
||||||
else:
|
else:
|
||||||
return [None] * len(self.x)
|
return [None] * len(self.x)
|
||||||
|
|
||||||
def get_aligned_spans_x2y(self, x_spans):
|
def get_aligned_spans_x2y(self, x_spans, allow_overlap=False):
|
||||||
return self._get_aligned_spans(self.y, x_spans, self.alignment.x2y)
|
return self._get_aligned_spans(self.y, x_spans, self.alignment.x2y, allow_overlap)
|
||||||
|
|
||||||
def get_aligned_spans_y2x(self, y_spans):
|
def get_aligned_spans_y2x(self, y_spans, allow_overlap=False):
|
||||||
return self._get_aligned_spans(self.x, y_spans, self.alignment.y2x)
|
return self._get_aligned_spans(self.x, y_spans, self.alignment.y2x, allow_overlap)
|
||||||
|
|
||||||
def _get_aligned_spans(self, doc, spans, align):
|
def _get_aligned_spans(self, doc, spans, align, allow_overlap):
|
||||||
seen = set()
|
seen = set()
|
||||||
output = []
|
output = []
|
||||||
for span in spans:
|
for span in spans:
|
||||||
indices = align[span.start : span.end].data.ravel()
|
indices = align[span.start : span.end].data.ravel()
|
||||||
indices = [idx for idx in indices if idx not in seen]
|
if not allow_overlap:
|
||||||
|
indices = [idx for idx in indices if idx not in seen]
|
||||||
if len(indices) >= 1:
|
if len(indices) >= 1:
|
||||||
aligned_span = Span(doc, indices[0], indices[-1] + 1, label=span.label)
|
aligned_span = Span(doc, indices[0], indices[-1] + 1, label=span.label)
|
||||||
target_text = span.text.lower().strip().replace(" ", "")
|
target_text = span.text.lower().strip().replace(" ", "")
|
||||||
|
@ -237,7 +238,7 @@ cdef class Example:
|
||||||
def get_aligned_ner(self):
|
def get_aligned_ner(self):
|
||||||
if not self.y.has_annotation("ENT_IOB"):
|
if not self.y.has_annotation("ENT_IOB"):
|
||||||
return [None] * len(self.x) # should this be 'missing' instead of 'None' ?
|
return [None] * len(self.x) # should this be 'missing' instead of 'None' ?
|
||||||
x_ents = self.get_aligned_spans_y2x(self.y.ents)
|
x_ents = self.get_aligned_spans_y2x(self.y.ents, allow_overlap=False)
|
||||||
# Default to 'None' for missing values
|
# Default to 'None' for missing values
|
||||||
x_tags = offsets_to_biluo_tags(
|
x_tags = offsets_to_biluo_tags(
|
||||||
self.x,
|
self.x,
|
||||||
|
|
|
@ -33,8 +33,8 @@ both documents.
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------ |
|
| -------------- | ------------------------------------------------------------------------------------------------------------------------ |
|
||||||
| `predicted` | The document containing (partial) predictions. Cannot be `None`. ~~Doc~~ |
|
| `predicted` | The document containing (partial) predictions. Cannot be `None`. ~~Doc~~ |
|
||||||
| `reference` | The document containing gold-standard annotations. Cannot be `None`. ~~Doc~~ |
|
| `reference` | The document containing gold-standard annotations. Cannot be `None`. ~~Doc~~ |
|
||||||
| _keyword-only_ | |
|
| _keyword-only_ | |
|
||||||
| `alignment` | An object holding the alignment between the tokens of the `predicted` and `reference` documents. ~~Optional[Alignment]~~ |
|
| `alignment` | An object holding the alignment between the tokens of the `predicted` and `reference` documents. ~~Optional[Alignment]~~ |
|
||||||
|
|
||||||
|
@ -56,11 +56,11 @@ see the [training format documentation](/api/data-formats#dict-input).
|
||||||
> example = Example.from_dict(predicted, {"words": token_ref, "tags": tags_ref})
|
> example = Example.from_dict(predicted, {"words": token_ref, "tags": tags_ref})
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| -------------- | ------------------------------------------------------------------------- |
|
| -------------- | ----------------------------------------------------------------------------------- |
|
||||||
| `predicted` | The document containing (partial) predictions. Cannot be `None`. ~~Doc~~ |
|
| `predicted` | The document containing (partial) predictions. Cannot be `None`. ~~Doc~~ |
|
||||||
| `example_dict` | `Dict[str, obj]` | The gold-standard annotations as a dictionary. Cannot be `None`. ~~Dict[str, Any]~~ |
|
| `example_dict` | The gold-standard annotations as a dictionary. Cannot be `None`. ~~Dict[str, Any]~~ |
|
||||||
| **RETURNS** | The newly constructed object. ~~Example~~ |
|
| **RETURNS** | The newly constructed object. ~~Example~~ |
|
||||||
|
|
||||||
## Example.text {#text tag="property"}
|
## Example.text {#text tag="property"}
|
||||||
|
|
||||||
|
@ -211,10 +211,11 @@ align to the tokenization in [`Example.predicted`](/api/example#predicted).
|
||||||
> assert [(ent.start, ent.end) for ent in ents_y2x] == [(0, 1)]
|
> assert [(ent.start, ent.end) for ent in ents_y2x] == [(0, 1)]
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| ----------- | ----------------------------------------------------------------------------- |
|
| --------------- | -------------------------------------------------------------------------------------------- |
|
||||||
| `y_spans` | `Span` objects aligned to the tokenization of `reference`. ~~Iterable[Span]~~ |
|
| `y_spans` | `Span` objects aligned to the tokenization of `reference`. ~~Iterable[Span]~~ |
|
||||||
| **RETURNS** | `Span` objects aligned to the tokenization of `predicted`. ~~List[Span]~~ |
|
| `allow_overlap` | Whether the resulting `Span` objects may overlap or not. Set to `False` by default. ~~bool~~ |
|
||||||
|
| **RETURNS** | `Span` objects aligned to the tokenization of `predicted`. ~~List[Span]~~ |
|
||||||
|
|
||||||
## Example.get_aligned_spans_x2y {#get_aligned_spans_x2y tag="method"}
|
## Example.get_aligned_spans_x2y {#get_aligned_spans_x2y tag="method"}
|
||||||
|
|
||||||
|
@ -238,10 +239,11 @@ against the original gold-standard annotation.
|
||||||
> assert [(ent.start, ent.end) for ent in ents_x2y] == [(0, 2)]
|
> assert [(ent.start, ent.end) for ent in ents_x2y] == [(0, 2)]
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| ----------- | ----------------------------------------------------------------------------- |
|
| --------------- | -------------------------------------------------------------------------------------------- |
|
||||||
| `x_spans` | `Span` objects aligned to the tokenization of `predicted`. ~~Iterable[Span]~~ |
|
| `x_spans` | `Span` objects aligned to the tokenization of `predicted`. ~~Iterable[Span]~~ |
|
||||||
| **RETURNS** | `Span` objects aligned to the tokenization of `reference`. ~~List[Span]~~ |
|
| `allow_overlap` | Whether the resulting `Span` objects may overlap or not. Set to `False` by default. ~~bool~~ |
|
||||||
|
| **RETURNS** | `Span` objects aligned to the tokenization of `reference`. ~~List[Span]~~ |
|
||||||
|
|
||||||
## Example.to_dict {#to_dict tag="method"}
|
## Example.to_dict {#to_dict tag="method"}
|
||||||
|
|
||||||
|
|
|
@ -137,14 +137,16 @@ Returns PRF scores for labeled or unlabeled spans.
|
||||||
> print(scores["ents_f"])
|
> print(scores["ents_f"])
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
| ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `examples` | The `Example` objects holding both the predictions and the correct gold-standard annotations. ~~Iterable[Example]~~ |
|
| `examples` | The `Example` objects holding both the predictions and the correct gold-standard annotations. ~~Iterable[Example]~~ |
|
||||||
| `attr` | The attribute to score. ~~str~~ |
|
| `attr` | The attribute to score. ~~str~~ |
|
||||||
| _keyword-only_ | |
|
| _keyword-only_ | |
|
||||||
| `getter` | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the `Span` objects for an individual `Doc`. ~~Callable[[Doc, str], Iterable[Span]]~~ |
|
| `getter` | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the `Span` objects for an individual `Doc`. ~~Callable[[Doc, str], Iterable[Span]]~~ |
|
||||||
| `has_annotation` | Defaults to `None`. If provided, `has_annotation(doc)` should return whether a `Doc` has annotation for this `attr`. Docs without annotation are skipped for scoring purposes. ~~Optional[Callable[[Doc], bool]]~~ |
|
| `has_annotation` | Defaults to `None`. If provided, `has_annotation(doc)` should return whether a `Doc` has annotation for this `attr`. Docs without annotation are skipped for scoring purposes. ~~str~~ |
|
||||||
| **RETURNS** | A dictionary containing the PRF scores under the keys `{attr}_p`, `{attr}_r`, `{attr}_f` and the per-type PRF scores under `{attr}_per_type`. ~~Dict[str, Union[float, Dict[str, float]]]~~ |
|
| `labeled` | Defaults to `True`. If set to `False`, two spans will be considered equal if their start and end match, irrespective of their label. ~~bool~~ |
|
||||||
|
| `allow_overlap` | Defaults to `False`. Whether or not to allow overlapping spans. If set to `False`, the alignment will automatically resolve conflicts. ~~bool~~ |
|
||||||
|
| **RETURNS** | A dictionary containing the PRF scores under the keys `{attr}_p`, `{attr}_r`, `{attr}_f` and the per-type PRF scores under `{attr}_per_type`. ~~Dict[str, Union[float, Dict[str, float]]]~~ |
|
||||||
|
|
||||||
## Scorer.score_deps {#score_deps tag="staticmethod" new="3"}
|
## Scorer.score_deps {#score_deps tag="staticmethod" new="3"}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user