Remove assumption of component being a Pipe object or having a .cfg attribute.

This commit is contained in:
Raphael Mitsch 2022-09-05 12:19:36 +02:00
parent 20c4a0d613
commit 03666f6e4e

View File

@@ -7,7 +7,6 @@ from typing import Optional, Tuple, Any, Dict, List
import numpy import numpy
import wasabi.tables import wasabi.tables
from ..pipeline import TrainablePipe, Pipe
from ..errors import Errors from ..errors import Errors
from ..training import Corpus from ..training import Corpus
from ._util import app, Arg, Opt, import_code, setup_gpu from ._util import app, Arg, Opt, import_code, setup_gpu
@@ -111,14 +110,11 @@ def find_threshold(
wasabi.msg.fail("Evaluation data not found", data_path, exits=1) wasabi.msg.fail("Evaluation data not found", data_path, exits=1)
nlp = util.load_model(model) nlp = util.load_model(model)
pipe: Optional[Pipe] = None pipe: Optional[Any] = None
try: try:
pipe = nlp.get_pipe(pipe_name) pipe = nlp.get_pipe(pipe_name)
except KeyError as err: except KeyError as err:
wasabi.msg.fail(title=str(err), exits=1) wasabi.msg.fail(title=str(err), exits=1)
if not isinstance(pipe, TrainablePipe):
raise TypeError(Errors.E1044)
if not hasattr(pipe, "scorer"): if not hasattr(pipe, "scorer"):
raise AttributeError(Errors.E1045) raise AttributeError(Errors.E1045)
@@ -170,10 +166,16 @@ def find_threshold(
threshold, threshold,
), ),
) )
nlp.get_pipe(pipe_name).cfg = set_nested_item(pipe.cfg, config_keys, threshold) if hasattr(pipe, "cfg"):
setattr(
nlp.get_pipe(pipe_name),
"cfg",
set_nested_item(getattr(pipe, "cfg"), config_keys, threshold),
)
scores[threshold] = nlp.evaluate(dev_dataset)[scores_key] scores[threshold] = nlp.evaluate(dev_dataset)[scores_key]
if not ( if not isinstance(scores[threshold], float) and not isinstance(
isinstance(scores[threshold], float) or isinstance(scores[threshold], int) scores[threshold], int
): ):
wasabi.msg.fail( wasabi.msg.fail(
f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric " f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric "