mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-08 06:04:57 +03:00
Fix mypy errors.
This commit is contained in:
parent
a7b56e82cf
commit
d689d97ab5
|
@ -1,10 +1,11 @@
|
|||
from pathlib import Path
|
||||
import logging
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union, Dict, cast
|
||||
|
||||
import numpy
|
||||
import wasabi.tables
|
||||
|
||||
from pipeline import Pipe
|
||||
from ._util import app, Arg, Opt
|
||||
from .. import util
|
||||
from ..pipeline import MultiLabel_TextCategorizer
|
||||
|
@ -60,10 +61,10 @@ def find_threshold(
|
|||
model_path: Union[str, Path],
|
||||
doc_path: Union[str, Path],
|
||||
*,
|
||||
average: str = _DEFAULTS["average"],
|
||||
pipe_name: Optional[str] = _DEFAULTS["pipe_name"],
|
||||
n_trials: int = _DEFAULTS["n_trials"],
|
||||
beta: float = _DEFAULTS["beta"],
|
||||
average: str = _DEFAULTS["average"], # type: ignore
|
||||
pipe_name: Optional[str] = _DEFAULTS["pipe_name"], # type: ignore
|
||||
n_trials: int = _DEFAULTS["n_trials"], # type: ignore
|
||||
beta: float = _DEFAULTS["beta"], # type: ignore
|
||||
verbose: bool = True,
|
||||
) -> Tuple[float, float]:
|
||||
"""
|
||||
|
@ -80,7 +81,7 @@ def find_threshold(
|
|||
"""
|
||||
|
||||
nlp = util.load_model(model_path)
|
||||
pipe: Optional[MultiLabel_TextCategorizer] = None
|
||||
pipe: Optional[Pipe] = None
|
||||
selected_pipe_name: Optional[str] = pipe_name
|
||||
|
||||
if average not in ("micro", "macro"):
|
||||
|
@ -99,7 +100,6 @@ def find_threshold(
|
|||
exits=1,
|
||||
)
|
||||
pipe = _pipe
|
||||
print(pipe_name, _pipe_name, pipe.labels)
|
||||
break
|
||||
elif pipe_name is None:
|
||||
if isinstance(_pipe, MultiLabel_TextCategorizer):
|
||||
|
@ -121,6 +121,8 @@ def find_threshold(
|
|||
"No component of type `MultiLabel_TextCategorizer` found in pipeline.",
|
||||
exits=1,
|
||||
)
|
||||
# This is purely for MyPy. Type checking is done in loop above already.
|
||||
assert isinstance(pipe, MultiLabel_TextCategorizer)
|
||||
|
||||
if verbose:
|
||||
print(
|
||||
|
@ -134,8 +136,8 @@ def find_threshold(
|
|||
t: {True: ref_pos_counts.copy(), False: ref_pos_counts.copy()}
|
||||
for t in thresholds
|
||||
}
|
||||
f_scores_per_label = {t: ref_pos_counts.copy() for t in thresholds}
|
||||
f_scores = {t: 0 for t in thresholds}
|
||||
f_scores_per_label = {t: {label: 0.0 for label in pipe.labels} for t in thresholds}
|
||||
f_scores = {t: 0.0 for t in thresholds}
|
||||
|
||||
# Count true/false positives for provided docs.
|
||||
doc_bin = DocBin()
|
||||
|
@ -196,7 +198,7 @@ def find_threshold(
|
|||
[f_scores_per_label[threshold][label] for label in ref_pos_counts]
|
||||
) / len(ref_pos_counts)
|
||||
|
||||
best_threshold = max(f_scores, key=f_scores.get)
|
||||
best_threshold = max(f_scores.keys(), key=(lambda key: f_scores[key]))
|
||||
if verbose:
|
||||
print(
|
||||
f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}.",
|
||||
|
|
|
@ -876,7 +876,7 @@ def test_cli_find_threshold(capsys):
|
|||
]
|
||||
|
||||
def init_nlp(
|
||||
component_factory_names: Tuple[str] = (),
|
||||
component_factory_names: Tuple[str, ...] = (),
|
||||
) -> Tuple[Language, List[Example]]:
|
||||
_nlp = English()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user