Support overriding registered functions in configs

This commit is contained in:
Adriane Boyd 2023-05-11 13:54:28 +02:00
parent b5af0fe836
commit 5b6aed3e6c
2 changed files with 27 additions and 2 deletions

View File

@ -415,6 +415,27 @@ def test_config_overrides():
assert nlp.pipe_names == ["tok2vec", "tagger"]
def test_config_overrides_registered_functions():
nlp = spacy.blank("en")
nlp.add_pipe("attribute_ruler")
with make_tempdir() as d:
nlp.to_disk(d)
nlp = spacy.load(
d,
config={
"components": {
"attribute_ruler": {
"scorer": {"@scorers": "spacy.tagger_scorer.v1"}
}
}
},
)
assert (
nlp.config["components"]["attribute_ruler"]["scorer"]["@scorers"]
== "spacy.tagger_scorer.v1"
)
def test_config_interpolation():
config = Config().from_str(nlp_config_string, interpolate=False)
assert config["corpora"]["train"]["path"] == "${paths.train}"

View File

@ -1530,10 +1530,14 @@ def set_dot_to_object(config: Config, section: str, value: Any) -> None:
def walk_dict(
node: Dict[str, Any], parent: List[str] = []
) -> Iterator[Tuple[List[str], Any]]:
"""Walk a dict and yield the path and values of the leaves."""
"""Walk a dict and yield the path and values of the leaves, treating
registered functions that start with @ as final values rather than dicts to
traverse."""
for key, value in node.items():
key_parent = [*parent, key]
if isinstance(value, dict):
if isinstance(value, dict) and not any(
value_key.startswith("@") for value_key in value
):
yield from walk_dict(value, key_parent)
else:
yield (key_parent, value)