2019-09-15 23:31:31 +03:00
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
from .errors import Errors
|
2015-05-27 04:18:16 +03:00
|
|
|
|
|
2015-04-05 23:29:30 +03:00
|
|
|
|
|
2015-05-24 21:07:18 +03:00
|
|
|
|
class PRFScore(object):
    """Accumulate true positives, false positives and false negatives over
    sets of items and expose precision, recall and F1 as properties.

    The counts (`tp`, `fp`, `fn`) are public and may also be incremented
    directly by callers.
    """

    def __init__(self):
        # Running counts; all metrics are derived from these three.
        self.tp = 0  # true positives
        self.fp = 0  # false positives
        self.fn = 0  # false negatives

    def score_set(self, cand, gold):
        """Update the counts from a candidate set and a gold-standard set."""
        overlap = cand & gold
        self.tp += len(overlap)
        self.fp += len(cand.difference(gold))
        self.fn += len(gold.difference(cand))

    @property
    def precision(self):
        """tp / (tp + fp); the tiny epsilon guards against division by zero."""
        denom = self.tp + self.fp + 1e-100
        return self.tp / denom

    @property
    def recall(self):
        """tp / (tp + fn); the tiny epsilon guards against division by zero."""
        denom = self.tp + self.fn + 1e-100
        return self.tp / denom

    @property
    def fscore(self):
        """Harmonic mean of precision and recall (F1)."""
        prec = self.precision
        rec = self.recall
        return 2 * ((prec * rec) / (prec + rec + 1e-100))
|
|
|
|
|
|
|
|
|
|
|
2019-09-15 23:31:31 +03:00
|
|
|
|
class ROCAUCScore(object):
    """Collect per-example prediction scores and gold labels, computing the
    ROC AUC lazily and caching the result until new data is recorded.
    """

    def __init__(self):
        self.golds = []  # gold labels, one per recorded example
        self.cands = []  # predicted scores, one per recorded example
        self.saved_score = 0.0  # cached AUC value
        self.saved_score_at_len = 0  # number of examples the cache covers

    def score_set(self, cand, gold):
        """Record one predicted score and its corresponding gold label."""
        self.golds.append(gold)
        self.cands.append(cand)

    @property
    def score(self):
        """ROC AUC over everything recorded so far.

        Returns -inf when the score is undefined (only one class present
        in the gold labels). The result is cached and only recomputed when
        new examples have been recorded since the last computation.
        """
        n_recorded = len(self.golds)
        if n_recorded != self.saved_score_at_len:
            try:
                self.saved_score = _roc_auc_score(self.golds, self.cands)
            except ValueError:
                # Only one class present in y_true: AUC is undefined.
                self.saved_score = -float("inf")
            self.saved_score_at_len = n_recorded
        return self.saved_score
|
|
|
|
|
|
|
|
|
|
|
2015-03-11 04:07:03 +03:00
|
|
|
|
class Scorer(object):
    """Compute evaluation scores.

    Aggregates one PRFScore accumulator per annotation layer (tokens, tags,
    deps, NER, ...) plus textcat F/AUC accumulators, updated one Example at
    a time via `score()`, and exposes the combined metrics as properties.
    """

    def __init__(self, eval_punct=False, pipeline=None):
        """Initialize the Scorer.

        eval_punct (bool): Evaluate the dependency attachments to and from
            punctuation.
        pipeline (iterable): Optional (name, component) pairs. If a component
            named "textcat" is present, its labels, positive label and
            multilabel flag are used to set up the textcat accumulators.
        RETURNS (Scorer): The newly created object.

        DOCS: https://spacy.io/api/scorer#init
        """
        # One PRFScore per evaluated annotation layer.
        self.tokens = PRFScore()
        self.sbd = PRFScore()
        self.unlabelled = PRFScore()
        self.labelled = PRFScore()
        # Per-dependency-label scores, keyed by lowercased dep label;
        # populated lazily in score().
        self.labelled_per_dep = dict()
        self.tags = PRFScore()
        self.pos = PRFScore()
        self.morphs = PRFScore()
        # Per-morphological-feature scores, keyed by feature field name;
        # populated lazily in score().
        self.morphs_per_feat = dict()
        self.sent_starts = PRFScore()
        self.ner = PRFScore()
        # Per-entity-label scores, keyed by entity label.
        self.ner_per_ents = dict()
        self.eval_punct = eval_punct
        self.textcat = PRFScore()
        self.textcat_f_per_cat = dict()
        self.textcat_auc_per_cat = dict()
        self.textcat_positive_label = None
        self.textcat_multilabel = False

        if pipeline:
            for name, component in pipeline:
                if name == "textcat":
                    self.textcat_multilabel = component.model.attrs["multi_label"]
                    self.textcat_positive_label = component.cfg.get(
                        "positive_label", None
                    )
                    # Pre-register one F and one AUC accumulator per label.
                    for label in component.cfg.get("labels", []):
                        self.textcat_auc_per_cat[label] = ROCAUCScore()
                        self.textcat_f_per_cat[label] = PRFScore()

    @property
    def tags_acc(self):
        """RETURNS (float): Part-of-speech tag accuracy (fine grained tags,
        i.e. `Token.tag`).
        """
        return self.tags.fscore * 100

    @property
    def pos_acc(self):
        """RETURNS (float): Part-of-speech tag accuracy (coarse grained pos,
        i.e. `Token.pos`).
        """
        return self.pos.fscore * 100

    @property
    def morphs_acc(self):
        """RETURNS (float): Morph tag accuracy (morphological features,
        i.e. `Token.morph`).
        """
        return self.morphs.fscore * 100

    @property
    def morphs_per_type(self):
        """RETURNS (dict): Precision/recall/F scores per morphological
        feature field.
        """
        return {
            k: {"p": v.precision * 100, "r": v.recall * 100, "f": v.fscore * 100}
            for k, v in self.morphs_per_feat.items()
        }

    @property
    def sent_p(self):
        """RETURNS (float): Precision for identification of sentence starts,
        i.e. `Token.is_sent_start`).
        """
        return self.sent_starts.precision * 100

    @property
    def sent_r(self):
        """RETURNS (float): Recall for identification of sentence starts,
        i.e. `Token.is_sent_start`).
        """
        return self.sent_starts.recall * 100

    @property
    def sent_f(self):
        """RETURNS (float): F-score for identification of sentence starts.
        i.e. `Token.is_sent_start`).
        """
        return self.sent_starts.fscore * 100

    @property
    def token_acc(self):
        """RETURNS (float): Tokenization accuracy."""
        return self.tokens.precision * 100

    @property
    def uas(self):
        """RETURNS (float): Unlabelled dependency score."""
        return self.unlabelled.fscore * 100

    @property
    def las(self):
        """RETURNS (float): Labelled dependency score."""
        return self.labelled.fscore * 100

    @property
    def las_per_type(self):
        """RETURNS (dict): Scores per dependency label.
        """
        return {
            k: {"p": v.precision * 100, "r": v.recall * 100, "f": v.fscore * 100}
            for k, v in self.labelled_per_dep.items()
        }

    @property
    def ents_p(self):
        """RETURNS (float): Named entity accuracy (precision)."""
        return self.ner.precision * 100

    @property
    def ents_r(self):
        """RETURNS (float): Named entity accuracy (recall)."""
        return self.ner.recall * 100

    @property
    def ents_f(self):
        """RETURNS (float): Named entity accuracy (F-score)."""
        return self.ner.fscore * 100

    @property
    def ents_per_type(self):
        """RETURNS (dict): Scores per entity label.
        """
        return {
            k: {"p": v.precision * 100, "r": v.recall * 100, "f": v.fscore * 100}
            for k, v in self.ner_per_ents.items()
        }

    @property
    def textcat_f(self):
        """RETURNS (float): f-score on positive label for binary classification,
        macro-averaged f-score for multilabel classification
        """
        if not self.textcat_multilabel:
            if self.textcat_positive_label:
                # binary classification
                return self.textcat.fscore * 100
        # multi-class and/or multi-label
        return (
            sum([score.fscore for label, score in self.textcat_f_per_cat.items()])
            / (len(self.textcat_f_per_cat) + 1e-100)
            * 100
        )

    @property
    def textcat_auc(self):
        """RETURNS (float): macro-averaged AUC ROC score for multilabel classification (-1 if undefined)
        """
        # Each per-category score is -inf when undefined, so the macro
        # average is clamped to -1 in that case.
        return max(
            sum([score.score for label, score in self.textcat_auc_per_cat.items()])
            / (len(self.textcat_auc_per_cat) + 1e-100),
            -1,
        )

    @property
    def textcats_auc_per_cat(self):
        """RETURNS (dict): AUC ROC Scores per textcat label.
        """
        return {
            k: {"roc_auc_score": max(v.score, -1)}
            for k, v in self.textcat_auc_per_cat.items()
        }

    @property
    def textcats_f_per_cat(self):
        """RETURNS (dict): F-scores per textcat label.
        """
        return {
            k: {"p": v.precision * 100, "r": v.recall * 100, "f": v.fscore * 100}
            for k, v in self.textcat_f_per_cat.items()
        }

    @property
    def scores(self):
        """RETURNS (dict): All scores mapped by key.
        """
        return {
            "uas": self.uas,
            "las": self.las,
            "las_per_type": self.las_per_type,
            "ents_p": self.ents_p,
            "ents_r": self.ents_r,
            "ents_f": self.ents_f,
            "ents_per_type": self.ents_per_type,
            "tags_acc": self.tags_acc,
            "pos_acc": self.pos_acc,
            "morphs_acc": self.morphs_acc,
            "morphs_per_type": self.morphs_per_type,
            "sent_p": self.sent_p,
            "sent_r": self.sent_r,
            "sent_f": self.sent_f,
            "token_acc": self.token_acc,
            "textcat_f": self.textcat_f,
            "textcat_auc": self.textcat_auc,
            "textcats_f_per_cat": self.textcats_f_per_cat,
            "textcats_auc_per_cat": self.textcats_auc_per_cat,
        }

    def score(self, example, verbose=False, punct_labels=("p", "punct")):
        """Update the evaluation scores from a single Example.

        example (Example): The predicted annotations + correct annotations.
        verbose (bool): Print debugging information.
        punct_labels (tuple): Dependency labels for punctuation. Used to
            evaluate dependency attachments to punctuation if `eval_punct` is
            `True`.

        DOCS: https://spacy.io/api/scorer#score
        """
        doc = example.predicted
        gold_doc = example.reference
        align = example.alignment
        # Phase 1: collect gold annotations, keyed by gold token index so
        # they can be compared against aligned candidate annotations below.
        gold_deps = set()
        gold_deps_per_dep = {}
        gold_tags = set()
        gold_pos = set()
        gold_morphs = set()
        gold_morphs_per_feat = {}
        gold_sent_starts = set()
        for gold_i, token in enumerate(gold_doc):
            gold_tags.add((gold_i, token.tag_))
            gold_pos.add((gold_i, token.pos_))
            gold_morphs.add((gold_i, token.morph_))
            if token.morph_:
                # Morph string is "Field=Value|Field=Value|..."; score each
                # feature field independently as well as the whole string.
                for feat in token.morph_.split("|"):
                    field, values = feat.split("=")
                    if field not in self.morphs_per_feat:
                        self.morphs_per_feat[field] = PRFScore()
                    if field not in gold_morphs_per_feat:
                        gold_morphs_per_feat[field] = set()
                    gold_morphs_per_feat[field].add((gold_i, feat))
            if token.sent_start:
                gold_sent_starts.add(gold_i)
            dep = token.dep_.lower()
            if dep not in punct_labels:
                gold_deps.add((gold_i, token.head.i, dep))
                if dep not in self.labelled_per_dep:
                    self.labelled_per_dep[dep] = PRFScore()
                if dep not in gold_deps_per_dep:
                    gold_deps_per_dep[dep] = set()
                gold_deps_per_dep[dep].add((gold_i, token.head.i, dep))
        # Phase 2: collect candidate (predicted) annotations, mapped into
        # gold token indices via the alignment.
        cand_deps = set()
        cand_deps_per_dep = {}
        cand_tags = set()
        cand_pos = set()
        cand_morphs = set()
        cand_morphs_per_feat = {}
        cand_sent_starts = set()
        for token in doc:
            if token.orth_.isspace():
                continue
            gold_i = align.cand_to_gold[token.i]
            if gold_i is None:
                # Predicted token has no gold counterpart: tokenization FP.
                self.tokens.fp += 1
            else:
                self.tokens.tp += 1
                cand_tags.add((gold_i, token.tag_))
                cand_pos.add((gold_i, token.pos_))
                cand_morphs.add((gold_i, token.morph_))
                if token.morph_:
                    for feat in token.morph_.split("|"):
                        field, values = feat.split("=")
                        if field not in self.morphs_per_feat:
                            self.morphs_per_feat[field] = PRFScore()
                        if field not in cand_morphs_per_feat:
                            cand_morphs_per_feat[field] = set()
                        cand_morphs_per_feat[field].add((gold_i, feat))
                if token.is_sent_start:
                    cand_sent_starts.add(gold_i)
            if token.dep_.lower() not in punct_labels and token.orth_.strip():
                gold_head = align.cand_to_gold[token.head.i]
                # None is indistinct, so we can't just add it to the set
                # Multiple (None, None) deps are possible
                if gold_i is None or gold_head is None:
                    self.unlabelled.fp += 1
                    self.labelled.fp += 1
                else:
                    cand_deps.add((gold_i, gold_head, token.dep_.lower()))
                    if token.dep_.lower() not in self.labelled_per_dep:
                        self.labelled_per_dep[token.dep_.lower()] = PRFScore()
                    if token.dep_.lower() not in cand_deps_per_dep:
                        cand_deps_per_dep[token.dep_.lower()] = set()
                    cand_deps_per_dep[token.dep_.lower()].add(
                        (gold_i, gold_head, token.dep_.lower())
                    )
        # Find all NER labels in gold and doc
        ent_labels = set(
            [k.label_ for k in gold_doc.ents] + [k.label_ for k in doc.ents]
        )
        # Set up all labels for per type scoring and prepare gold per type
        gold_per_ents = {ent_label: set() for ent_label in ent_labels}
        for ent_label in ent_labels:
            if ent_label not in self.ner_per_ents:
                self.ner_per_ents[ent_label] = PRFScore()
        # Find all candidate labels, for all and per type
        gold_ents = set()
        cand_ents = set()
        # If we have missing values in the gold, we can't easily tell whether
        # our NER predictions are true.
        # It seems bad but it's what we've always done.
        if all(token.ent_iob != 0 for token in gold_doc):
            for ent in gold_doc.ents:
                gold_ent = (ent.label_, ent.start, ent.end - 1)
                gold_ents.add(gold_ent)
                gold_per_ents[ent.label_].add((ent.label_, ent.start, ent.end - 1))
        cand_per_ents = {ent_label: set() for ent_label in ent_labels}
        for ent in doc.ents:
            # Map predicted entity boundaries into gold token indices.
            first = align.cand_to_gold[ent.start]
            last = align.cand_to_gold[ent.end - 1]
            if first is None or last is None:
                # Unalignable entity boundaries count as false positives.
                self.ner.fp += 1
                self.ner_per_ents[ent.label_].fp += 1
            else:
                cand_ents.add((ent.label_, first, last))
                cand_per_ents[ent.label_].add((ent.label_, first, last))
        # Scores per ent
        for k, v in self.ner_per_ents.items():
            if k in cand_per_ents:
                v.score_set(cand_per_ents[k], gold_per_ents[k])
        # Score for all ents
        self.ner.score_set(cand_ents, gold_ents)
        self.tags.score_set(cand_tags, gold_tags)
        self.pos.score_set(cand_pos, gold_pos)
        self.morphs.score_set(cand_morphs, gold_morphs)
        for field in self.morphs_per_feat:
            self.morphs_per_feat[field].score_set(
                cand_morphs_per_feat.get(field, set()),
                gold_morphs_per_feat.get(field, set()),
            )
        self.sent_starts.score_set(cand_sent_starts, gold_sent_starts)
        self.labelled.score_set(cand_deps, gold_deps)
        for dep in self.labelled_per_dep:
            self.labelled_per_dep[dep].score_set(
                cand_deps_per_dep.get(dep, set()), gold_deps_per_dep.get(dep, set())
            )
        # Unlabelled attachment only compares (dependent, head) pairs.
        self.unlabelled.score_set(
            set(item[:2] for item in cand_deps), set(item[:2] for item in gold_deps)
        )
        # Textcat is only scored when gold, predicted and configured label
        # sets all agree; otherwise a label mismatch is a hard error below.
        if (
            len(gold_doc.cats) > 0
            and set(self.textcat_f_per_cat)
            == set(self.textcat_auc_per_cat)
            == set(gold_doc.cats)
            and set(gold_doc.cats) == set(doc.cats)
        ):
            goldcat = max(gold_doc.cats, key=gold_doc.cats.get)
            candcat = max(doc.cats, key=doc.cats.get)
            if self.textcat_positive_label:
                self.textcat.score_set(
                    set([self.textcat_positive_label]) & set([candcat]),
                    set([self.textcat_positive_label]) & set([goldcat]),
                )
            for label in set(gold_doc.cats):
                self.textcat_auc_per_cat[label].score_set(
                    doc.cats[label], gold_doc.cats[label]
                )
                self.textcat_f_per_cat[label].score_set(
                    set([label]) & set([candcat]), set([label]) & set([goldcat])
                )
        elif len(self.textcat_f_per_cat) > 0:
            model_labels = set(self.textcat_f_per_cat)
            eval_labels = set(gold_doc.cats)
            raise ValueError(
                Errors.E162.format(model_labels=model_labels, eval_labels=eval_labels)
            )
        elif len(self.textcat_auc_per_cat) > 0:
            model_labels = set(self.textcat_auc_per_cat)
            eval_labels = set(gold_doc.cats)
            raise ValueError(
                Errors.E162.format(model_labels=model_labels, eval_labels=eval_labels)
            )
        if verbose:
            # F = false positive dep, M = missed (false negative) dep.
            gold_words = gold_doc.words
            for w_id, h_id, dep in cand_deps - gold_deps:
                print("F", gold_words[w_id], dep, gold_words[h_id])
            for w_id, h_id, dep in gold_deps - cand_deps:
                print("M", gold_words[w_id], dep, gold_words[h_id])
|
2019-09-15 23:31:31 +03:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#############################################################################
|
|
|
|
|
#
|
|
|
|
|
# The following implementation of roc_auc_score() is adapted from
|
|
|
|
|
# scikit-learn, which is distributed under the following license:
|
|
|
|
|
#
|
|
|
|
|
# New BSD License
|
|
|
|
|
#
|
|
|
|
|
# Copyright (c) 2007–2019 The scikit-learn developers.
|
|
|
|
|
# All rights reserved.
|
|
|
|
|
#
|
|
|
|
|
#
|
|
|
|
|
# Redistribution and use in source and binary forms, with or without
|
|
|
|
|
# modification, are permitted provided that the following conditions are met:
|
|
|
|
|
#
|
|
|
|
|
# a. Redistributions of source code must retain the above copyright notice,
|
|
|
|
|
# this list of conditions and the following disclaimer.
|
|
|
|
|
# b. Redistributions in binary form must reproduce the above copyright
|
|
|
|
|
# notice, this list of conditions and the following disclaimer in the
|
|
|
|
|
# documentation and/or other materials provided with the distribution.
|
|
|
|
|
# c. Neither the name of the Scikit-learn Developers nor the names of
|
|
|
|
|
# its contributors may be used to endorse or promote products
|
|
|
|
|
# derived from this software without specific prior written
|
2019-09-18 20:56:55 +03:00
|
|
|
|
# permission.
|
2019-09-15 23:31:31 +03:00
|
|
|
|
#
|
|
|
|
|
#
|
|
|
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
|
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
|
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
|
|
|
|
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
|
|
|
|
|
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
|
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
|
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
|
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
|
|
|
|
|
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
|
|
|
|
|
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
|
|
|
|
|
# DAMAGE.
|
|
|
|
|
|
2019-09-18 20:56:55 +03:00
|
|
|
|
|
2019-09-15 23:31:31 +03:00
|
|
|
|
def _roc_auc_score(y_true, y_score):
    """Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC)
    from prediction scores.

    Note: this implementation is restricted to the binary classification task.

    Parameters
    ----------
    y_true : array, shape = [n_samples]
        True binary labels; exactly two distinct label values must be present.
    y_score : array, shape = [n_samples]
        Target scores: probability estimates of the positive class,
        confidence values, or non-thresholded decision values.

    Returns
    -------
    auc : float

    Raises
    ------
    ValueError
        If ``y_true`` does not contain exactly two distinct labels.

    References
    ----------
    .. [1] `Wikipedia entry for the Receiver operating characteristic
            <https://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_
    .. [2] Fawcett T. An introduction to ROC analysis[J]. Pattern Recognition
           Letters, 2006, 27(8):861-874.
    """
    # The ROC curve is only defined when both classes are present.
    distinct_labels = np.unique(y_true)
    if distinct_labels.size != 2:
        raise ValueError(Errors.E165)
    fpr, tpr, _thresholds = _roc_curve(y_true, y_score)
    return _auc(fpr, tpr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _roc_curve(y_true, y_score):
    """Compute Receiver operating characteristic (ROC).

    Note: this implementation is restricted to the binary classification task.

    Parameters
    ----------
    y_true : array, shape = [n_samples]
        True binary labels.
    y_score : array, shape = [n_samples]
        Target scores: probability estimates of the positive class,
        confidence values, or non-thresholded decision values.

    Returns
    -------
    fpr : array, shape = [>2]
        Increasing false positive rates such that element i is the false
        positive rate of predictions with score >= thresholds[i].
    tpr : array, shape = [>2]
        Increasing true positive rates such that element i is the true
        positive rate of predictions with score >= thresholds[i].
    thresholds : array, shape = [n_thresholds]
        Decreasing score thresholds used to compute fpr and tpr.
        ``thresholds[0]`` represents no instances being predicted and is
        arbitrarily set to ``max(y_score) + 1``.
    """
    fps, tps, thresholds = _binary_clf_curve(y_true, y_score)

    # Prepend an artificial point so the curve always starts at (0, 0);
    # its threshold is set one above the highest observed score.
    tps = np.r_[0, tps]
    fps = np.r_[0, fps]
    thresholds = np.r_[thresholds[0] + 1, thresholds]

    def _rate(counts):
        # Normalize by the total count (last cumulative value); if that
        # class has no samples the rate is undefined, so return all-NaN.
        if counts[-1] <= 0:
            return np.repeat(np.nan, counts.shape)
        return counts / counts[-1]

    return _rate(fps), _rate(tps), thresholds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _binary_clf_curve(y_true, y_score):
    """Calculate true and false positives per binary classification threshold.

    Parameters
    ----------
    y_true : array, shape = [n_samples]
        True targets of binary classification (positive class is label 1).
    y_score : array, shape = [n_samples]
        Estimated probabilities or decision function.

    Returns
    -------
    fps : array, shape = [n_thresholds]
        A count of false positives, at index i being the number of negative
        samples assigned a score >= thresholds[i]. The total number of
        negative samples is equal to fps[-1] (thus true negatives are given
        by fps[-1] - fps).
    tps : array, shape = [n_thresholds <= len(np.unique(y_score))]
        An increasing count of true positives, at index i being the number
        of positive samples assigned a score >= thresholds[i]. The total
        number of positive samples is equal to tps[-1] (thus false
        negatives are given by tps[-1] - tps).
    thresholds : array, shape = [n_thresholds]
        Decreasing score values.
    """
    pos_label = 1.0
    weight = 1.0

    # Flatten inputs and turn the labels into a boolean positive mask.
    labels = np.ravel(y_true) == pos_label
    scores = np.ravel(y_score)

    # Sort by decreasing score; mergesort is stable so tied scores keep
    # their original relative order.
    order = np.argsort(scores, kind="mergesort")[::-1]
    scores = scores[order]
    labels = labels[order]

    # Scores typically contain many ties: keep one index per distinct
    # score value, plus the final position to close the curve.
    distinct_idxs = np.where(np.diff(scores))[0]
    threshold_idxs = np.r_[distinct_idxs, labels.size - 1]

    # Cumulative positives give tps with decreasing threshold; everything
    # else at or above the threshold is a false positive.
    tps = _stable_cumsum(labels * weight)[threshold_idxs]
    fps = 1 + threshold_idxs - tps
    return fps, tps, scores[threshold_idxs]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _stable_cumsum(arr, axis=None, rtol=1e-05, atol=1e-08):
|
|
|
|
|
"""Use high precision for cumsum and check that final value matches sum
|
|
|
|
|
|
|
|
|
|
Parameters
|
|
|
|
|
----------
|
|
|
|
|
arr : array-like
|
|
|
|
|
To be cumulatively summed as flat
|
|
|
|
|
axis : int, optional
|
|
|
|
|
Axis along which the cumulative sum is computed.
|
|
|
|
|
The default (None) is to compute the cumsum over the flattened array.
|
|
|
|
|
rtol : float
|
|
|
|
|
Relative tolerance, see ``np.allclose``
|
|
|
|
|
atol : float
|
|
|
|
|
Absolute tolerance, see ``np.allclose``
|
|
|
|
|
"""
|
|
|
|
|
out = np.cumsum(arr, axis=axis, dtype=np.float64)
|
|
|
|
|
expected = np.sum(arr, axis=axis, dtype=np.float64)
|
2019-09-18 20:56:55 +03:00
|
|
|
|
if not np.all(
|
|
|
|
|
np.isclose(
|
|
|
|
|
out.take(-1, axis=axis), expected, rtol=rtol, atol=atol, equal_nan=True
|
|
|
|
|
)
|
|
|
|
|
):
|
2019-09-15 23:31:31 +03:00
|
|
|
|
raise ValueError(Errors.E163)
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _auc(x, y):
|
|
|
|
|
"""Compute Area Under the Curve (AUC) using the trapezoidal rule
|
|
|
|
|
|
|
|
|
|
This is a general function, given points on a curve. For computing the
|
|
|
|
|
area under the ROC-curve, see :func:`roc_auc_score`.
|
|
|
|
|
|
|
|
|
|
Parameters
|
|
|
|
|
----------
|
|
|
|
|
x : array, shape = [n]
|
|
|
|
|
x coordinates. These must be either monotonic increasing or monotonic
|
|
|
|
|
decreasing.
|
|
|
|
|
y : array, shape = [n]
|
|
|
|
|
y coordinates.
|
|
|
|
|
|
|
|
|
|
Returns
|
|
|
|
|
-------
|
|
|
|
|
auc : float
|
|
|
|
|
"""
|
|
|
|
|
x = np.ravel(x)
|
|
|
|
|
y = np.ravel(y)
|
|
|
|
|
|
|
|
|
|
direction = 1
|
|
|
|
|
dx = np.diff(x)
|
|
|
|
|
if np.any(dx < 0):
|
|
|
|
|
if np.all(dx <= 0):
|
|
|
|
|
direction = -1
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(Errors.E164.format(x))
|
|
|
|
|
|
|
|
|
|
area = direction * np.trapz(y, x)
|
|
|
|
|
if isinstance(area, np.memmap):
|
|
|
|
|
# Reductions such as .sum used internally in np.trapz do not return a
|
|
|
|
|
# scalar by default for numpy.memmap instances contrary to
|
|
|
|
|
# regular numpy.ndarray instances.
|
|
|
|
|
area = area.dtype.type(area)
|
|
|
|
|
return area
|