Remove assumption of component being a Pipe object or having a .cfg attribute.

This commit is contained in:
Raphael Mitsch 2022-09-05 12:19:36 +02:00
parent 20c4a0d613
commit 03666f6e4e

View File

@ -7,7 +7,6 @@ from typing import Optional, Tuple, Any, Dict, List
import numpy
import wasabi.tables
from ..pipeline import TrainablePipe, Pipe
from ..errors import Errors
from ..training import Corpus
from ._util import app, Arg, Opt, import_code, setup_gpu
@ -111,14 +110,11 @@ def find_threshold(
wasabi.msg.fail("Evaluation data not found", data_path, exits=1)
nlp = util.load_model(model)
pipe: Optional[Pipe] = None
pipe: Optional[Any] = None
try:
pipe = nlp.get_pipe(pipe_name)
except KeyError as err:
wasabi.msg.fail(title=str(err), exits=1)
if not isinstance(pipe, TrainablePipe):
raise TypeError(Errors.E1044)
if not hasattr(pipe, "scorer"):
raise AttributeError(Errors.E1045)
@ -170,10 +166,16 @@ def find_threshold(
threshold,
),
)
nlp.get_pipe(pipe_name).cfg = set_nested_item(pipe.cfg, config_keys, threshold)
if hasattr(pipe, "cfg"):
setattr(
nlp.get_pipe(pipe_name),
"cfg",
set_nested_item(getattr(pipe, "cfg"), config_keys, threshold),
)
scores[threshold] = nlp.evaluate(dev_dataset)[scores_key]
if not (
isinstance(scores[threshold], float) or isinstance(scores[threshold], int)
if not isinstance(scores[threshold], float) and not isinstance(
scores[threshold], int
):
wasabi.msg.fail(
f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric "