Fix mypy errors.

This commit is contained in:
Raphael Mitsch 2022-08-09 10:03:43 +02:00
parent a7b56e82cf
commit d689d97ab5
2 changed files with 13 additions and 11 deletions

View File

@@ -1,10 +1,11 @@
from pathlib import Path
import logging
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple, Union, Dict, cast
import numpy
import wasabi.tables
+from pipeline import Pipe
from ._util import app, Arg, Opt
from .. import util
from ..pipeline import MultiLabel_TextCategorizer
@@ -60,10 +61,10 @@ def find_threshold(
model_path: Union[str, Path],
doc_path: Union[str, Path],
*,
-average: str = _DEFAULTS["average"],
-pipe_name: Optional[str] = _DEFAULTS["pipe_name"],
-n_trials: int = _DEFAULTS["n_trials"],
-beta: float = _DEFAULTS["beta"],
+average: str = _DEFAULTS["average"], # type: ignore
+pipe_name: Optional[str] = _DEFAULTS["pipe_name"], # type: ignore
+n_trials: int = _DEFAULTS["n_trials"], # type: ignore
+beta: float = _DEFAULTS["beta"], # type: ignore
verbose: bool = True,
) -> Tuple[float, float]:
"""
@@ -80,7 +81,7 @@ def find_threshold(
"""
nlp = util.load_model(model_path)
-pipe: Optional[MultiLabel_TextCategorizer] = None
+pipe: Optional[Pipe] = None
selected_pipe_name: Optional[str] = pipe_name
if average not in ("micro", "macro"):
@@ -99,7 +100,6 @@ def find_threshold(
exits=1,
)
pipe = _pipe
-print(pipe_name, _pipe_name, pipe.labels)
break
elif pipe_name is None:
if isinstance(_pipe, MultiLabel_TextCategorizer):
@@ -121,6 +121,8 @@ def find_threshold(
"No component of type `MultiLabel_TextCategorizer` found in pipeline.",
exits=1,
)
+# This is purely for MyPy. Type checking is done in loop above already.
+assert isinstance(pipe, MultiLabel_TextCategorizer)
if verbose:
print(
@@ -134,8 +136,8 @@ def find_threshold(
t: {True: ref_pos_counts.copy(), False: ref_pos_counts.copy()}
for t in thresholds
}
-f_scores_per_label = {t: ref_pos_counts.copy() for t in thresholds}
-f_scores = {t: 0 for t in thresholds}
+f_scores_per_label = {t: {label: 0.0 for label in pipe.labels} for t in thresholds}
+f_scores = {t: 0.0 for t in thresholds}
# Count true/false positives for provided docs.
doc_bin = DocBin()
@@ -196,7 +198,7 @@ def find_threshold(
[f_scores_per_label[threshold][label] for label in ref_pos_counts]
) / len(ref_pos_counts)
-best_threshold = max(f_scores, key=f_scores.get)
+best_threshold = max(f_scores.keys(), key=(lambda key: f_scores[key]))
if verbose:
print(
f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}.",

View File

@@ -876,7 +876,7 @@ def test_cli_find_threshold(capsys):
]
def init_nlp(
-component_factory_names: Tuple[str] = (),
+component_factory_names: Tuple[str, ...] = (),
) -> Tuple[Language, List[Example]]:
_nlp = English()