mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-08 14:14:57 +03:00
Fix mypy errors.
This commit is contained in:
parent
a7b56e82cf
commit
d689d97ab5
|
@ -1,10 +1,11 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union, Dict, cast
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import wasabi.tables
|
import wasabi.tables
|
||||||
|
|
||||||
|
from pipeline import Pipe
|
||||||
from ._util import app, Arg, Opt
|
from ._util import app, Arg, Opt
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..pipeline import MultiLabel_TextCategorizer
|
from ..pipeline import MultiLabel_TextCategorizer
|
||||||
|
@ -60,10 +61,10 @@ def find_threshold(
|
||||||
model_path: Union[str, Path],
|
model_path: Union[str, Path],
|
||||||
doc_path: Union[str, Path],
|
doc_path: Union[str, Path],
|
||||||
*,
|
*,
|
||||||
average: str = _DEFAULTS["average"],
|
average: str = _DEFAULTS["average"], # type: ignore
|
||||||
pipe_name: Optional[str] = _DEFAULTS["pipe_name"],
|
pipe_name: Optional[str] = _DEFAULTS["pipe_name"], # type: ignore
|
||||||
n_trials: int = _DEFAULTS["n_trials"],
|
n_trials: int = _DEFAULTS["n_trials"], # type: ignore
|
||||||
beta: float = _DEFAULTS["beta"],
|
beta: float = _DEFAULTS["beta"], # type: ignore
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
) -> Tuple[float, float]:
|
) -> Tuple[float, float]:
|
||||||
"""
|
"""
|
||||||
|
@ -80,7 +81,7 @@ def find_threshold(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
nlp = util.load_model(model_path)
|
nlp = util.load_model(model_path)
|
||||||
pipe: Optional[MultiLabel_TextCategorizer] = None
|
pipe: Optional[Pipe] = None
|
||||||
selected_pipe_name: Optional[str] = pipe_name
|
selected_pipe_name: Optional[str] = pipe_name
|
||||||
|
|
||||||
if average not in ("micro", "macro"):
|
if average not in ("micro", "macro"):
|
||||||
|
@ -99,7 +100,6 @@ def find_threshold(
|
||||||
exits=1,
|
exits=1,
|
||||||
)
|
)
|
||||||
pipe = _pipe
|
pipe = _pipe
|
||||||
print(pipe_name, _pipe_name, pipe.labels)
|
|
||||||
break
|
break
|
||||||
elif pipe_name is None:
|
elif pipe_name is None:
|
||||||
if isinstance(_pipe, MultiLabel_TextCategorizer):
|
if isinstance(_pipe, MultiLabel_TextCategorizer):
|
||||||
|
@ -121,6 +121,8 @@ def find_threshold(
|
||||||
"No component of type `MultiLabel_TextCategorizer` found in pipeline.",
|
"No component of type `MultiLabel_TextCategorizer` found in pipeline.",
|
||||||
exits=1,
|
exits=1,
|
||||||
)
|
)
|
||||||
|
# This is purely for MyPy. Type checking is done in loop above already.
|
||||||
|
assert isinstance(pipe, MultiLabel_TextCategorizer)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(
|
print(
|
||||||
|
@ -134,8 +136,8 @@ def find_threshold(
|
||||||
t: {True: ref_pos_counts.copy(), False: ref_pos_counts.copy()}
|
t: {True: ref_pos_counts.copy(), False: ref_pos_counts.copy()}
|
||||||
for t in thresholds
|
for t in thresholds
|
||||||
}
|
}
|
||||||
f_scores_per_label = {t: ref_pos_counts.copy() for t in thresholds}
|
f_scores_per_label = {t: {label: 0.0 for label in pipe.labels} for t in thresholds}
|
||||||
f_scores = {t: 0 for t in thresholds}
|
f_scores = {t: 0.0 for t in thresholds}
|
||||||
|
|
||||||
# Count true/false positives for provided docs.
|
# Count true/false positives for provided docs.
|
||||||
doc_bin = DocBin()
|
doc_bin = DocBin()
|
||||||
|
@ -196,7 +198,7 @@ def find_threshold(
|
||||||
[f_scores_per_label[threshold][label] for label in ref_pos_counts]
|
[f_scores_per_label[threshold][label] for label in ref_pos_counts]
|
||||||
) / len(ref_pos_counts)
|
) / len(ref_pos_counts)
|
||||||
|
|
||||||
best_threshold = max(f_scores, key=f_scores.get)
|
best_threshold = max(f_scores.keys(), key=(lambda key: f_scores[key]))
|
||||||
if verbose:
|
if verbose:
|
||||||
print(
|
print(
|
||||||
f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}.",
|
f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}.",
|
||||||
|
|
|
@ -876,7 +876,7 @@ def test_cli_find_threshold(capsys):
|
||||||
]
|
]
|
||||||
|
|
||||||
def init_nlp(
|
def init_nlp(
|
||||||
component_factory_names: Tuple[str] = (),
|
component_factory_names: Tuple[str, ...] = (),
|
||||||
) -> Tuple[Language, List[Example]]:
|
) -> Tuple[Language, List[Example]]:
|
||||||
_nlp = English()
|
_nlp = English()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user