mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Support overriding registered functions in configs (#12623)
Support overriding registered functions in configs. Previously the registry name was parsed as a section name rather than as a registry name.
This commit is contained in:
parent
c067b5264c
commit
65f6c9cd10
|
@ -13,6 +13,7 @@ from spacy.ml.models import (
|
||||||
build_Tok2Vec_model,
|
build_Tok2Vec_model,
|
||||||
)
|
)
|
||||||
from spacy.schemas import ConfigSchema, ConfigSchemaPretrain
|
from spacy.schemas import ConfigSchema, ConfigSchemaPretrain
|
||||||
|
from spacy.training import Example
|
||||||
from spacy.util import (
|
from spacy.util import (
|
||||||
load_config,
|
load_config,
|
||||||
load_config_from_str,
|
load_config_from_str,
|
||||||
|
@ -422,6 +423,55 @@ def test_config_overrides():
|
||||||
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.filterwarnings("ignore:\\[W036")
def test_config_overrides_registered_functions():
    """Check that registered functions can be overridden via `config=` on load.

    A pipeline with an attribute ruler is saved to a temporary directory and
    reloaded twice: once replacing the scorer with another registered scorer,
    and once replacing it with a scorer whose own argument is itself a
    registered function (a nested override).
    """
    # Override payloads, built up front so the load calls stay readable.
    tagger_scorer_override = {
        "components": {
            "attribute_ruler": {
                "scorer": {"@scorers": "spacy.tagger_scorer.v1"},
            }
        }
    }
    nested_scorer_override = {
        "components": {
            "attribute_ruler": {
                "scorer": {
                    "@scorers": "spacy.overlapping_labeled_spans_scorer.v1",
                    "spans_key": {"@misc": "test_some_other_key"},
                }
            }
        }
    }

    nlp = spacy.blank("en")
    nlp.add_pipe("attribute_ruler")
    with make_tempdir() as d:
        nlp.to_disk(d)

        # Simple case: swap the scorer for a different registered scorer.
        nlp_re1 = spacy.load(d, config=tagger_scorer_override)
        scorer_cfg = nlp_re1.config["components"]["attribute_ruler"]["scorer"]
        assert scorer_cfg["@scorers"] == "spacy.tagger_scorer.v1"

        @registry.misc("test_some_other_key")
        def misc_some_other_key():
            return "some_other_key"

        # Nested case: the scorer's `spans_key` argument is itself a
        # registered function reference.
        nlp_re2 = spacy.load(d, config=nested_scorer_override)
        nested_scorer_cfg = nlp_re2.config["components"]["attribute_ruler"]["scorer"]
        assert nested_scorer_cfg["spans_key"] == {"@misc": "test_some_other_key"}

        # run dummy evaluation (will return None scores) in order to test that
        # the spans_key value in the nested override is working as intended in
        # the config
        example = Example.from_dict(nlp_re2.make_doc("a b c"), {})
        scores = nlp_re2.evaluate([example])
        assert "spans_some_other_key_f" in scores
|
||||||
|
|
||||||
|
|
||||||
def test_config_interpolation():
|
def test_config_interpolation():
|
||||||
config = Config().from_str(nlp_config_string, interpolate=False)
|
config = Config().from_str(nlp_config_string, interpolate=False)
|
||||||
assert config["corpora"]["train"]["path"] == "${paths.train}"
|
assert config["corpora"]["train"]["path"] == "${paths.train}"
|
||||||
|
|
|
@ -252,6 +252,10 @@ def test_minor_version(a1, a2, b1, b2, is_match):
|
||||||
{"training.batch_size": 128, "training.optimizer.learn_rate": 0.01},
|
{"training.batch_size": 128, "training.optimizer.learn_rate": 0.01},
|
||||||
{"training": {"batch_size": 128, "optimizer": {"learn_rate": 0.01}}},
|
{"training": {"batch_size": 128, "optimizer": {"learn_rate": 0.01}}},
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
{"attribute_ruler.scorer.@scorers": "spacy.tagger_scorer.v1"},
|
||||||
|
{"attribute_ruler": {"scorer": {"@scorers": "spacy.tagger_scorer.v1"}}},
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_dot_to_dict(dot_notation, expected):
|
def test_dot_to_dict(dot_notation, expected):
|
||||||
|
@ -260,6 +264,29 @@ def test_dot_to_dict(dot_notation, expected):
|
||||||
assert util.dict_to_dot(result) == dot_notation
|
assert util.dict_to_dot(result) == dot_notation
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "dot_notation,expected",
    [
        (
            {"token.pos": True, "token._.xyz": True},
            {"token": {"pos": True, "_": {"xyz": True}}},
        ),
        (
            {"training.batch_size": 128, "training.optimizer.learn_rate": 0.01},
            {"training": {"batch_size": 128, "optimizer": {"learn_rate": 0.01}}},
        ),
        (
            {"attribute_ruler.scorer": {"@scorers": "spacy.tagger_scorer.v1"}},
            {"attribute_ruler": {"scorer": {"@scorers": "spacy.tagger_scorer.v1"}}},
        ),
    ],
)
def test_dot_to_dict_overrides(dot_notation, expected):
    """Round-trip dot notation through `dot_to_dict` and `dict_to_dot`.

    With `for_overrides=True`, the last case expects the registered-function
    dict (the one holding an "@scorers" key) to come back as a single value
    under "attribute_ruler.scorer" instead of being flattened further.
    """
    nested = util.dot_to_dict(dot_notation)
    assert nested == expected
    flattened = util.dict_to_dot(nested, for_overrides=True)
    assert flattened == dot_notation
|
||||||
|
|
||||||
|
|
||||||
def test_set_dot_to_object():
|
def test_set_dot_to_object():
|
||||||
config = {"foo": {"bar": 1, "baz": {"x": "y"}}, "test": {"a": {"b": "c"}}}
|
config = {"foo": {"bar": 1, "baz": {"x": "y"}}, "test": {"a": {"b": "c"}}}
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
|
|
|
@ -534,7 +534,7 @@ def load_model_from_path(
|
||||||
if not meta:
|
if not meta:
|
||||||
meta = get_model_meta(model_path)
|
meta = get_model_meta(model_path)
|
||||||
config_path = model_path / "config.cfg"
|
config_path = model_path / "config.cfg"
|
||||||
overrides = dict_to_dot(config)
|
overrides = dict_to_dot(config, for_overrides=True)
|
||||||
config = load_config(config_path, overrides=overrides)
|
config = load_config(config_path, overrides=overrides)
|
||||||
nlp = load_model_from_config(
|
nlp = load_model_from_config(
|
||||||
config,
|
config,
|
||||||
|
@ -1502,14 +1502,19 @@ def dot_to_dict(values: Dict[str, Any]) -> Dict[str, dict]:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def dict_to_dot(obj: Dict[str, dict], *, for_overrides: bool = False) -> Dict[str, Any]:
    """Convert a nested dict to dot notation. For example: {"token": {"pos":
    True, "_": {"xyz": True }}} becomes {"token.pos": True, "token._.xyz": True}.

    obj (Dict[str, dict]): The dict to convert.
    for_overrides (bool): Whether to enable special handling for registered
        functions in overrides. The flag is passed through to walk_dict, which
        then yields a dict containing a key that starts with "@" as a single
        value instead of descending into it.
    RETURNS (Dict[str, Any]): The key/value pairs.
    """
    # walk_dict yields (path, leaf) pairs; each path is joined with "." to
    # form the flattened key.
    return {
        ".".join(key): value
        for key, value in walk_dict(obj, for_overrides=for_overrides)
    }
|
||||||
|
|
||||||
|
|
||||||
def dot_to_object(config: Config, section: str):
|
def dot_to_object(config: Config, section: str):
|
||||||
|
@ -1551,13 +1556,20 @@ def set_dot_to_object(config: Config, section: str, value: Any) -> None:
|
||||||
|
|
||||||
|
|
||||||
def walk_dict(
    node: Dict[str, Any], parent: List[str] = [], *, for_overrides: bool = False
) -> Iterator[Tuple[List[str], Any]]:
    """Walk a dict and yield the path and values of the leaves.

    node (Dict[str, Any]): The (possibly nested) dict to walk.
    parent (List[str]): Path of keys leading to `node`. The shared default
        list is never mutated: a new list is built for every key.
    for_overrides (bool): Whether to treat registered functions that start with
        @ as final values rather than dicts to traverse.
    YIELDS (Tuple[List[str], Any]): The key path and the value found there.
    """
    for key, value in node.items():
        # Build a fresh path list so neither `parent` nor the default
        # argument is ever modified.
        key_parent = [*parent, key]
        # A dict counts as a registered-function block when any of its keys
        # starts with "@". Non-string keys are skipped by the isinstance check
        # so they cannot raise AttributeError on `.startswith`.
        is_registered_function = (
            for_overrides
            and isinstance(value, dict)
            and any(
                isinstance(value_key, str) and value_key.startswith("@")
                for value_key in value
            )
        )
        if isinstance(value, dict) and not is_registered_function:
            yield from walk_dict(value, key_parent, for_overrides=for_overrides)
        else:
            yield (key_parent, value)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user