diff --git a/spacy/tests/serialize/test_serialize_config.py b/spacy/tests/serialize/test_serialize_config.py index 85e6f8b2c..7d96c9063 100644 --- a/spacy/tests/serialize/test_serialize_config.py +++ b/spacy/tests/serialize/test_serialize_config.py @@ -415,6 +415,27 @@ def test_config_overrides(): assert nlp.pipe_names == ["tok2vec", "tagger"] +def test_config_overrides_registered_functions(): + nlp = spacy.blank("en") + nlp.add_pipe("attribute_ruler") + with make_tempdir() as d: + nlp.to_disk(d) + nlp = spacy.load( + d, + config={ + "components": { + "attribute_ruler": { + "scorer": {"@scorers": "spacy.tagger_scorer.v1"} + } + } + }, + ) + assert ( + nlp.config["components"]["attribute_ruler"]["scorer"]["@scorers"] + == "spacy.tagger_scorer.v1" + ) + + def test_config_interpolation(): config = Config().from_str(nlp_config_string, interpolate=False) assert config["corpora"]["train"]["path"] == "${paths.train}" diff --git a/spacy/util.py b/spacy/util.py index 8cc89217d..ac38fbc90 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -1530,10 +1530,14 @@ def set_dot_to_object(config: Config, section: str, value: Any) -> None: def walk_dict( node: Dict[str, Any], parent: List[str] = [] ) -> Iterator[Tuple[List[str], Any]]: - """Walk a dict and yield the path and values of the leaves.""" + """Walk a dict and yield the path and values of the leaves, treating + registered functions that start with @ as final values rather than dicts to + traverse.""" for key, value in node.items(): key_parent = [*parent, key] - if isinstance(value, dict): + if isinstance(value, dict) and not any( + value_key.startswith("@") for value_key in value + ): yield from walk_dict(value, key_parent) else: yield (key_parent, value)