Fix types in Scorer

This commit is contained in:
Adriane Boyd 2020-07-29 10:39:33 +02:00
parent ac24adec73
commit c689ae8f0a

View File

@ -2,7 +2,7 @@ from typing import Optional, Iterable, Dict, Any, Callable, Tuple, TYPE_CHECKING
import numpy as np
from .gold import Example
from .tokens import Token
from .tokens import Token, Doc
from .errors import Errors
from .util import get_lang_class
from .morphology import Morphology
@ -49,7 +49,7 @@ class PRFScore:
class ROCAUCScore:
"""An AUC ROC score."""
def __init__(self):
def __init__(self) -> None:
self.golds = []
self.cands = []
self.saved_score = 0.0
@ -250,14 +250,14 @@ class Scorer:
examples: Iterable[Example],
attr: str,
*,
getter: Callable[[Token, str], Any] = getattr,
getter: Callable[[Doc, str], Any] = getattr,
**cfg,
) -> Dict[str, Any]:
"""Returns PRF scores for labeled spans.
examples (Iterable[Example]): Examples to score
attr (str): The attribute to score.
getter (Callable[[Token, str], Any]): Defaults to getattr. If provided,
getter (Callable[[Doc, str], Any]): Defaults to getattr. If provided,
getter(doc, attr) should return the spans for the individual doc.
RETURNS (Dict[str, Any]): A dictionary containing the PRF scores under
the keys attr_p/r/f and the per-type PRF scores under attr_per_type.
@ -315,7 +315,7 @@ class Scorer:
examples: Iterable[Example],
attr: str,
*,
getter: Callable[[Token, str], Any] = getattr,
getter: Callable[[Doc, str], Any] = getattr,
labels: Iterable[str] = tuple(),
multi_label: bool = True,
positive_label: Optional[str] = None,
@ -327,7 +327,7 @@ class Scorer:
examples (Iterable[Example]): Examples to score
attr (str): The attribute to score.
getter (Callable[[Token, str], Any]): Defaults to getattr. If provided,
getter (Callable[[Doc, str], Any]): Defaults to getattr. If provided,
getter(doc, attr) should return the values for the individual doc.
labels (Iterable[str]): The set of possible labels. Defaults to [].
multi_label (bool): Whether the attribute allows multiple labels.