Expand initialize/training config validation

Validate both `[initialize]` and `[training]` in `debug config` and
`nlp.initialize()`, using separate config validation error blocks that
indicate which block of the config is being validated.
This commit is contained in:
Adriane Boyd 2021-01-12 17:17:00 +01:00
parent ad43cbb042
commit 5fb8b7037a
2 changed files with 21 additions and 2 deletions

View File

@ -7,7 +7,7 @@ import typer
from ._util import Arg, Opt, show_validation_error, parse_config_overrides from ._util import Arg, Opt, show_validation_error, parse_config_overrides
from ._util import import_code, debug_cli from ._util import import_code, debug_cli
from ..schemas import ConfigSchemaTraining from ..schemas import ConfigSchemaInit, ConfigSchemaTraining
from ..util import registry from ..util import registry
from .. import util from .. import util
@ -55,6 +55,11 @@ def debug_config(
config = util.load_config(config_path, overrides=overrides) config = util.load_config(config_path, overrides=overrides)
nlp = util.load_model_from_config(config) nlp = util.load_model_from_config(config)
config = nlp.config.interpolate() config = nlp.config.interpolate()
msg.divider("Config validation for [initialize]")
with show_validation_error(config_path):
T = registry.resolve(config["initialize"], schema=ConfigSchemaInit)
msg.divider("Config validation for [training]")
with show_validation_error(config_path):
T = registry.resolve(config["training"], schema=ConfigSchemaTraining) T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
dot_names = [T["train_corpus"], T["dev_corpus"]] dot_names = [T["train_corpus"], T["dev_corpus"]]
util.resolve_dot_names(config, dot_names) util.resolve_dot_names(config, dot_names)

View File

@ -1,4 +1,5 @@
from typing import Union, Dict, Optional, Any, List, IO, TYPE_CHECKING from typing import Union, Dict, Optional, Any, List, IO, TYPE_CHECKING
from pydantic import BaseModel
from thinc.api import Config, fix_random_seed, set_gpu_allocator from thinc.api import Config, fix_random_seed, set_gpu_allocator
from thinc.api import ConfigValidationError from thinc.api import ConfigValidationError
from pathlib import Path from pathlib import Path
@ -12,7 +13,7 @@ import tqdm
from ..lookups import Lookups from ..lookups import Lookups
from ..vectors import Vectors from ..vectors import Vectors
from ..errors import Errors from ..errors import Errors
from ..schemas import ConfigSchemaTraining from ..schemas import ConfigSchemaInit, ConfigSchemaTraining
from ..util import registry, load_model_from_config, resolve_dot_names, logger from ..util import registry, load_model_from_config, resolve_dot_names, logger
from ..util import load_model, ensure_path, OOV_RANK, DEFAULT_OOV_PROB from ..util import load_model, ensure_path, OOV_RANK, DEFAULT_OOV_PROB
@ -23,6 +24,9 @@ if TYPE_CHECKING:
def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language": def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
raw_config = config raw_config = config
config = raw_config.interpolate() config = raw_config.interpolate()
# Validate config before accessing values
_validate_config_block(config, "initialize", ConfigSchemaInit)
_validate_config_block(config, "training", ConfigSchemaTraining)
if config["training"]["seed"] is not None: if config["training"]["seed"] is not None:
fix_random_seed(config["training"]["seed"]) fix_random_seed(config["training"]["seed"])
allocator = config["training"]["gpu_allocator"] allocator = config["training"]["gpu_allocator"]
@ -269,3 +273,13 @@ def ensure_shape(lines):
length = len(captured) length = len(captured)
yield f"{length} {width}" yield f"{length} {width}"
yield from captured yield from captured
def _validate_config_block(config: Config, block: str, schema: BaseModel):
    """Validate a single top-level block of the config against a schema.

    Resolves ``config[block]`` through the registry with validation enabled.
    If that raises a ConfigValidationError, a new ConfigValidationError is
    raised in its place: its title names the block that failed and its
    description points the user at the ``debug config`` command.

    config (Config): The config that contains the block.
    block (str): Key of the block to validate; also shown in the error title.
    schema (BaseModel): Schema handed to ``registry.resolve``.
    """
    try:
        registry.resolve(config[block], validate=True, schema=schema)
    except ConfigValidationError as original:
        # `from None` drops the original exception context so only the
        # re-titled, block-specific error is reported.
        raise ConfigValidationError.from_error(
            original,
            title=f"Config validation error for [{block}]",
            desc="For more information run: python -m spacy debug config config.cfg",
        ) from None