Validate section refs in debug config

This commit is contained in:
Ines Montani 2020-09-22 12:24:39 +02:00
parent db7126ead9
commit 5e3b796b12
2 changed files with 39 additions and 3 deletions

View File

@ -2,7 +2,7 @@ from typing import Optional, Dict, Any, Union, List
from pathlib import Path from pathlib import Path
from wasabi import msg, table from wasabi import msg, table
from thinc.api import Config from thinc.api import Config
from thinc.config import VARIABLE_RE from thinc.config import VARIABLE_RE, ConfigValidationError
import typer import typer
from ._util import Arg, Opt, show_validation_error, parse_config_overrides from ._util import Arg, Opt, show_validation_error, parse_config_overrides
@ -51,7 +51,10 @@ def debug_config(
msg.divider("Config validation") msg.divider("Config validation")
with show_validation_error(config_path): with show_validation_error(config_path):
config = util.load_config(config_path, overrides=overrides) config = util.load_config(config_path, overrides=overrides)
nlp, _ = util.load_model_from_config(config) nlp, resolved = util.load_model_from_config(config)
# Use the resolved config here in case the user has one function returning
# a dict of corpora etc.
check_section_refs(resolved, ["training.dev_corpus", "training.train_corpus"])
msg.good("Config is valid") msg.good("Config is valid")
if show_vars: if show_vars:
variables = get_variables(config) variables = get_variables(config)
@ -93,3 +96,23 @@ def get_variables(config: Config) -> Dict[str, Any]:
value = util.dot_to_object(config, path) value = util.dot_to_object(config, path)
result[variable] = repr(value) result[variable] = repr(value)
return result return result
def check_section_refs(config: Config, fields: List[str]) -> None:
    """Check that config fields holding dotted references point at entries
    that exist in the same config.

    Each name in `fields` is looked up in `config` with `util.dot_to_object`.
    Names whose lookup raises a KeyError are skipped. For every other name,
    the value found is itself looked up in `config`, and a KeyError from that
    second lookup is recorded as a validation error for the field.

    config: The config to validate.
    fields: Dot-notation names of the fields whose values are references.
    RAISES (ConfigValidationError): If at least one reference could not be
        looked up. All failing fields are reported together.
    """
    problems = []
    for dotted_name in fields:
        try:
            target = util.dot_to_object(config, dotted_name)
        except KeyError:
            # The field isn't present in this config, so there is nothing
            # to validate for it
            continue
        try:
            util.dot_to_object(config, target)
        except KeyError:
            problems.append(
                {
                    "loc": dotted_name.split("."),
                    "msg": f"not a valid section reference: {target}",
                }
            )
    if not problems:
        return
    raise ConfigValidationError(config, problems)

View File

@ -7,7 +7,8 @@ from spacy.cli.init_config import init_config, RECOMMENDATIONS
from spacy.cli._util import validate_project_commands, parse_config_overrides from spacy.cli._util import validate_project_commands, parse_config_overrides
from spacy.cli._util import load_project_config, substitute_project_variables from spacy.cli._util import load_project_config, substitute_project_variables
from spacy.cli._util import string_to_list, OVERRIDES_ENV_VAR from spacy.cli._util import string_to_list, OVERRIDES_ENV_VAR
from thinc.config import ConfigValidationError from spacy.cli.debug_config import check_section_refs
from thinc.config import ConfigValidationError, Config
import srsly import srsly
import os import os
@ -413,3 +414,15 @@ def test_string_to_list(value):
def test_string_to_list_intify(value): def test_string_to_list_intify(value):
assert string_to_list(value, intify=False) == ["1", "2", "3"] assert string_to_list(value, intify=False) == ["1", "2", "3"]
assert string_to_list(value, intify=True) == [1, 2, 3] assert string_to_list(value, intify=True) == [1, 2, 3]
def test_check_section_refs():
    # "a.b.c" holds the reference "a.d.e", which is present in the data below.
    # "f.g" holds the reference "d.f", and there is no top-level "d" section.
    data = {"a": {"b": {"c": "a.d.e"}, "d": {"e": 1}}, "f": {"g": "d.f"}}
    config = Config(data)
    # Neither a valid section reference ("a.b.c") nor a field that doesn't
    # exist in this config ("x.y.z") should raise
    for fields in (["a.b.c"], ["x.y.z"]):
        check_section_refs(config, fields)
    # A single invalid section reference among the fields fails validation
    with pytest.raises(ConfigValidationError):
        check_section_refs(config, ["a.b.c", "f.g"])