diff --git a/spacy/cli/find_threshold.py b/spacy/cli/find_threshold.py index 6ce4f7321..a451b7d4a 100644 --- a/spacy/cli/find_threshold.py +++ b/spacy/cli/find_threshold.py @@ -1,5 +1,4 @@ import functools -from functools import partial import operator from pathlib import Path import logging @@ -7,6 +6,7 @@ from typing import Optional, Tuple, Any, Dict, List import numpy import wasabi.tables +from confection import Config from ..pipeline import TrainablePipe, Pipe from ..errors import Errors @@ -121,7 +121,6 @@ def find_threshold( raise TypeError(Errors.E1044) if not hasattr(pipe, "scorer"): raise AttributeError(Errors.E1045) - setattr(pipe, "scorer", partial(pipe.scorer.func, beta=beta)) if not silent: wasabi.msg.info( @@ -150,6 +149,7 @@ def find_threshold( scores: Dict[float, float] = {} for threshold in numpy.linspace(0, 1, n_trials): pipe.cfg = set_nested_item(pipe.cfg, config_keys, threshold) + nlp._pipe_configs[pipe_name] = Config(pipe.cfg) scores[threshold] = nlp.evaluate(dev_dataset)[scores_key] if not ( isinstance(scores[threshold], float) or isinstance(scores[threshold], int)