Fix mypy errors.

This commit is contained in:
Raphael Mitsch 2022-08-09 10:03:43 +02:00
parent a7b56e82cf
commit d689d97ab5
2 changed files with 13 additions and 11 deletions

View File

@ -1,10 +1,11 @@
from pathlib import Path from pathlib import Path
import logging import logging
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union, Dict, cast
import numpy import numpy
import wasabi.tables import wasabi.tables
from pipeline import Pipe
from ._util import app, Arg, Opt from ._util import app, Arg, Opt
from .. import util from .. import util
from ..pipeline import MultiLabel_TextCategorizer from ..pipeline import MultiLabel_TextCategorizer
@ -60,10 +61,10 @@ def find_threshold(
model_path: Union[str, Path], model_path: Union[str, Path],
doc_path: Union[str, Path], doc_path: Union[str, Path],
*, *,
average: str = _DEFAULTS["average"], average: str = _DEFAULTS["average"], # type: ignore
pipe_name: Optional[str] = _DEFAULTS["pipe_name"], pipe_name: Optional[str] = _DEFAULTS["pipe_name"], # type: ignore
n_trials: int = _DEFAULTS["n_trials"], n_trials: int = _DEFAULTS["n_trials"], # type: ignore
beta: float = _DEFAULTS["beta"], beta: float = _DEFAULTS["beta"], # type: ignore
verbose: bool = True, verbose: bool = True,
) -> Tuple[float, float]: ) -> Tuple[float, float]:
""" """
@ -80,7 +81,7 @@ def find_threshold(
""" """
nlp = util.load_model(model_path) nlp = util.load_model(model_path)
pipe: Optional[MultiLabel_TextCategorizer] = None pipe: Optional[Pipe] = None
selected_pipe_name: Optional[str] = pipe_name selected_pipe_name: Optional[str] = pipe_name
if average not in ("micro", "macro"): if average not in ("micro", "macro"):
@ -99,7 +100,6 @@ def find_threshold(
exits=1, exits=1,
) )
pipe = _pipe pipe = _pipe
print(pipe_name, _pipe_name, pipe.labels)
break break
elif pipe_name is None: elif pipe_name is None:
if isinstance(_pipe, MultiLabel_TextCategorizer): if isinstance(_pipe, MultiLabel_TextCategorizer):
@ -121,6 +121,8 @@ def find_threshold(
"No component of type `MultiLabel_TextCategorizer` found in pipeline.", "No component of type `MultiLabel_TextCategorizer` found in pipeline.",
exits=1, exits=1,
) )
# This is purely for mypy. Type checking is already done in the loop above.
assert isinstance(pipe, MultiLabel_TextCategorizer)
if verbose: if verbose:
print( print(
@ -134,8 +136,8 @@ def find_threshold(
t: {True: ref_pos_counts.copy(), False: ref_pos_counts.copy()} t: {True: ref_pos_counts.copy(), False: ref_pos_counts.copy()}
for t in thresholds for t in thresholds
} }
f_scores_per_label = {t: ref_pos_counts.copy() for t in thresholds} f_scores_per_label = {t: {label: 0.0 for label in pipe.labels} for t in thresholds}
f_scores = {t: 0 for t in thresholds} f_scores = {t: 0.0 for t in thresholds}
# Count true/false positives for provided docs. # Count true/false positives for provided docs.
doc_bin = DocBin() doc_bin = DocBin()
@ -196,7 +198,7 @@ def find_threshold(
[f_scores_per_label[threshold][label] for label in ref_pos_counts] [f_scores_per_label[threshold][label] for label in ref_pos_counts]
) / len(ref_pos_counts) ) / len(ref_pos_counts)
best_threshold = max(f_scores, key=f_scores.get) best_threshold = max(f_scores.keys(), key=(lambda key: f_scores[key]))
if verbose: if verbose:
print( print(
f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}.", f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}.",

View File

@ -876,7 +876,7 @@ def test_cli_find_threshold(capsys):
] ]
def init_nlp( def init_nlp(
component_factory_names: Tuple[str] = (), component_factory_names: Tuple[str, ...] = (),
) -> Tuple[Language, List[Example]]: ) -> Tuple[Language, List[Example]]:
_nlp = English() _nlp = English()