Update nlp.config (workaround).

This commit is contained in:
Raphael Mitsch 2022-09-02 16:25:57 +02:00
parent 24b69a1be8
commit 73432c6bfb

View File

@ -1,5 +1,4 @@
import functools import functools
from functools import partial
import operator import operator
from pathlib import Path from pathlib import Path
import logging import logging
@ -7,6 +6,7 @@ from typing import Optional, Tuple, Any, Dict, List
import numpy import numpy
import wasabi.tables import wasabi.tables
from confection import Config
from ..pipeline import TrainablePipe, Pipe from ..pipeline import TrainablePipe, Pipe
from ..errors import Errors from ..errors import Errors
@ -121,7 +121,6 @@ def find_threshold(
raise TypeError(Errors.E1044) raise TypeError(Errors.E1044)
if not hasattr(pipe, "scorer"): if not hasattr(pipe, "scorer"):
raise AttributeError(Errors.E1045) raise AttributeError(Errors.E1045)
setattr(pipe, "scorer", partial(pipe.scorer.func, beta=beta))
if not silent: if not silent:
wasabi.msg.info( wasabi.msg.info(
@ -150,6 +149,7 @@ def find_threshold(
scores: Dict[float, float] = {} scores: Dict[float, float] = {}
for threshold in numpy.linspace(0, 1, n_trials): for threshold in numpy.linspace(0, 1, n_trials):
pipe.cfg = set_nested_item(pipe.cfg, config_keys, threshold) pipe.cfg = set_nested_item(pipe.cfg, config_keys, threshold)
nlp._pipe_configs[pipe_name] = Config(pipe.cfg)
scores[threshold] = nlp.evaluate(dev_dataset)[scores_key] scores[threshold] = nlp.evaluate(dev_dataset)[scores_key]
if not ( if not (
isinstance(scores[threshold], float) or isinstance(scores[threshold], int) isinstance(scores[threshold], float) or isinstance(scores[threshold], int)