Fix types in Scorer

This commit is contained in:
Adriane Boyd 2020-07-29 10:39:33 +02:00
parent ac24adec73
commit c689ae8f0a

View File

@ -2,7 +2,7 @@ from typing import Optional, Iterable, Dict, Any, Callable, Tuple, TYPE_CHECKING
import numpy as np import numpy as np
from .gold import Example from .gold import Example
from .tokens import Token from .tokens import Token, Doc
from .errors import Errors from .errors import Errors
from .util import get_lang_class from .util import get_lang_class
from .morphology import Morphology from .morphology import Morphology
@ -49,7 +49,7 @@ class PRFScore:
class ROCAUCScore: class ROCAUCScore:
"""An AUC ROC score.""" """An AUC ROC score."""
def __init__(self): def __init__(self) -> None:
self.golds = [] self.golds = []
self.cands = [] self.cands = []
self.saved_score = 0.0 self.saved_score = 0.0
@ -250,14 +250,14 @@ class Scorer:
examples: Iterable[Example], examples: Iterable[Example],
attr: str, attr: str,
*, *,
getter: Callable[[Token, str], Any] = getattr, getter: Callable[[Doc, str], Any] = getattr,
**cfg, **cfg,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Returns PRF scores for labeled spans. """Returns PRF scores for labeled spans.
examples (Iterable[Example]): Examples to score examples (Iterable[Example]): Examples to score
attr (str): The attribute to score. attr (str): The attribute to score.
getter (Callable[[Token, str], Any]): Defaults to getattr. If provided, getter (Callable[[Doc, str], Any]): Defaults to getattr. If provided,
getter(doc, attr) should return the spans for the individual doc. getter(doc, attr) should return the spans for the individual doc.
RETURNS (Dict[str, Any]): A dictionary containing the PRF scores under RETURNS (Dict[str, Any]): A dictionary containing the PRF scores under
the keys attr_p/r/f and the per-type PRF scores under attr_per_type. the keys attr_p/r/f and the per-type PRF scores under attr_per_type.
@ -315,7 +315,7 @@ class Scorer:
examples: Iterable[Example], examples: Iterable[Example],
attr: str, attr: str,
*, *,
getter: Callable[[Token, str], Any] = getattr, getter: Callable[[Doc, str], Any] = getattr,
labels: Iterable[str] = tuple(), labels: Iterable[str] = tuple(),
multi_label: bool = True, multi_label: bool = True,
positive_label: Optional[str] = None, positive_label: Optional[str] = None,
@ -327,7 +327,7 @@ class Scorer:
examples (Iterable[Example]): Examples to score examples (Iterable[Example]): Examples to score
attr (str): The attribute to score. attr (str): The attribute to score.
getter (Callable[[Token, str], Any]): Defaults to getattr. If provided, getter (Callable[[Doc, str], Any]): Defaults to getattr. If provided,
getter(doc, attr) should return the values for the individual doc. getter(doc, attr) should return the values for the individual doc.
labels (Iterable[str]): The set of possible labels. Defaults to []. labels (Iterable[str]): The set of possible labels. Defaults to [].
multi_label (bool): Whether the attribute allows multiple labels. multi_label (bool): Whether the attribute allows multiple labels.