diff --git a/spacy/tests/serialize/test_serialize_config.py b/spacy/tests/serialize/test_serialize_config.py index 7d96c9063..d305d7719 100644 --- a/spacy/tests/serialize/test_serialize_config.py +++ b/spacy/tests/serialize/test_serialize_config.py @@ -420,7 +420,7 @@ def test_config_overrides_registered_functions(): nlp.add_pipe("attribute_ruler") with make_tempdir() as d: nlp.to_disk(d) - nlp = spacy.load( + nlp_re1 = spacy.load( d, config={ "components": { @@ -431,9 +431,26 @@ def test_config_overrides_registered_functions(): }, ) assert ( - nlp.config["components"]["attribute_ruler"]["scorer"]["@scorers"] + nlp_re1.config["components"]["attribute_ruler"]["scorer"]["@scorers"] == "spacy.tagger_scorer.v1" ) + nlp_re2 = spacy.load( + d, + config={ + "components": { + "attribute_ruler": { + "scorer": { + "@scorers": "spacy.overlapping_labeled_spans_scorer.v1", + "spans_key": "some_other_key", + } + } + } + }, + ) + assert ( + nlp_re2.config["components"]["attribute_ruler"]["scorer"]["spans_key"] + == "some_other_key" + ) def test_config_interpolation(): diff --git a/spacy/tests/test_misc.py b/spacy/tests/test_misc.py index 618f17334..23a66263d 100644 --- a/spacy/tests/test_misc.py +++ b/spacy/tests/test_misc.py @@ -237,6 +237,10 @@ def test_minor_version(a1, a2, b1, b2, is_match): {"training.batch_size": 128, "training.optimizer.learn_rate": 0.01}, {"training": {"batch_size": 128, "optimizer": {"learn_rate": 0.01}}}, ), + ( + {"attribute_ruler.scorer.@scorers": "spacy.tagger_scorer.v1"}, + {"attribute_ruler": {"scorer": {"@scorers": "spacy.tagger_scorer.v1"}}}, + ), ], ) def test_dot_to_dict(dot_notation, expected): @@ -245,6 +249,29 @@ def test_dot_to_dict(dot_notation, expected): assert util.dict_to_dot(result) == dot_notation +@pytest.mark.parametrize( + "dot_notation,expected", + [ + ( + {"token.pos": True, "token._.xyz": True}, + {"token": {"pos": True, "_": {"xyz": True}}}, + ), + ( + {"training.batch_size": 128, "training.optimizer.learn_rate": 0.01}, + 
{"training": {"batch_size": 128, "optimizer": {"learn_rate": 0.01}}}, + ), + ( + {"attribute_ruler.scorer": {"@scorers": "spacy.tagger_scorer.v1"}}, + {"attribute_ruler": {"scorer": {"@scorers": "spacy.tagger_scorer.v1"}}}, + ), + ], +) +def test_dot_to_dict_overrides(dot_notation, expected): + result = util.dot_to_dict(dot_notation) + assert result == expected + assert util.dict_to_dot(result, for_overrides=True) == dot_notation + + def test_set_dot_to_object(): config = {"foo": {"bar": 1, "baz": {"x": "y"}}, "test": {"a": {"b": "c"}}} with pytest.raises(KeyError): diff --git a/spacy/util.py b/spacy/util.py index ac38fbc90..9f1b886bf 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -511,7 +511,7 @@ def load_model_from_path( if not meta: meta = get_model_meta(model_path) config_path = model_path / "config.cfg" - overrides = dict_to_dot(config) + overrides = dict_to_dot(config, for_overrides=True) config = load_config(config_path, overrides=overrides) nlp = load_model_from_config( config, @@ -1479,14 +1479,19 @@ def dot_to_dict(values: Dict[str, Any]) -> Dict[str, dict]: return result -def dict_to_dot(obj: Dict[str, dict]) -> Dict[str, Any]: +def dict_to_dot(obj: Dict[str, dict], *, for_overrides: bool = False) -> Dict[str, Any]: - """Convert dot notation to a dict. For example: + """Convert a dict to dot notation. For example: {"token": {"pos": True, "_": {"xyz": True }}} becomes {"token.pos": True, "token._.xyz": True}. - values (Dict[str, dict]): The dict to convert. + obj (Dict[str, dict]): The dict to convert. + for_overrides (bool): Whether to keep a dict holding a registered function + (a key starting with @) as a single value instead of flattening it. RETURNS (Dict[str, Any]): The key/value pairs. 
""" - return {".".join(key): value for key, value in walk_dict(obj)} + return { + ".".join(key): value + for key, value in walk_dict(obj, for_overrides=for_overrides) + } def dot_to_object(config: Config, section: str): @@ -1528,17 +1533,20 @@ def set_dot_to_object(config: Config, section: str, value: Any) -> None: def walk_dict( - node: Dict[str, Any], parent: List[str] = [] + node: Dict[str, Any], parent: List[str] = [], *, for_overrides: bool = False ) -> Iterator[Tuple[List[str], Any]]: - """Walk a dict and yield the path and values of the leaves, treating - registered functions that start with @ as final values rather than dicts to - traverse.""" + """Walk a dict and yield the path and values of the leaves. + + for_overrides (bool): Whether to treat registered functions that start with + @ as final values rather than dicts to traverse. + """ for key, value in node.items(): key_parent = [*parent, key] - if isinstance(value, dict) and not any( - value_key.startswith("@") for value_key in value + if isinstance(value, dict) and ( + not for_overrides + or not any(value_key.startswith("@") for value_key in value) ): - yield from walk_dict(value, key_parent) + yield from walk_dict(value, key_parent, for_overrides=for_overrides) else: yield (key_parent, value)