Expand initialize/training config validation

Validate both `[initialize]` and `[training]` in `debug config` and
`nlp.initialize()` with separate config validation error blocks that
indicate which block of the config is being validated.
This commit is contained in:
Adriane Boyd 2021-01-12 17:17:00 +01:00
parent ad43cbb042
commit 5fb8b7037a
2 changed files with 21 additions and 2 deletions

View File

@ -7,7 +7,7 @@ import typer
from ._util import Arg, Opt, show_validation_error, parse_config_overrides
from ._util import import_code, debug_cli
from ..schemas import ConfigSchemaTraining
from ..schemas import ConfigSchemaInit, ConfigSchemaTraining
from ..util import registry
from .. import util
@ -55,6 +55,11 @@ def debug_config(
config = util.load_config(config_path, overrides=overrides)
nlp = util.load_model_from_config(config)
config = nlp.config.interpolate()
msg.divider("Config validation for [initialize]")
with show_validation_error(config_path):
T = registry.resolve(config["initialize"], schema=ConfigSchemaInit)
msg.divider("Config validation for [training]")
with show_validation_error(config_path):
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
dot_names = [T["train_corpus"], T["dev_corpus"]]
util.resolve_dot_names(config, dot_names)

View File

@ -1,4 +1,5 @@
from typing import Union, Dict, Optional, Any, List, IO, TYPE_CHECKING
from pydantic import BaseModel
from thinc.api import Config, fix_random_seed, set_gpu_allocator
from thinc.api import ConfigValidationError
from pathlib import Path
@ -12,7 +13,7 @@ import tqdm
from ..lookups import Lookups
from ..vectors import Vectors
from ..errors import Errors
from ..schemas import ConfigSchemaTraining
from ..schemas import ConfigSchemaInit, ConfigSchemaTraining
from ..util import registry, load_model_from_config, resolve_dot_names, logger
from ..util import load_model, ensure_path, OOV_RANK, DEFAULT_OOV_PROB
@ -23,6 +24,9 @@ if TYPE_CHECKING:
def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
raw_config = config
config = raw_config.interpolate()
# Validate config before accessing values
_validate_config_block(config, "initialize", ConfigSchemaInit)
_validate_config_block(config, "training", ConfigSchemaTraining)
if config["training"]["seed"] is not None:
fix_random_seed(config["training"]["seed"])
allocator = config["training"]["gpu_allocator"]
@ -269,3 +273,13 @@ def ensure_shape(lines):
length = len(captured)
yield f"{length} {width}"
yield from captured
def _validate_config_block(config: Config, block: str, schema: BaseModel):
    """Resolve a single config block against a schema and re-raise any
    validation failure with a title that names the block.

    config (Config): The config; only config[block] is resolved.
    block (str): Name of the block to validate, e.g. "initialize" or
        "training".
    schema (BaseModel): Schema handed to registry.resolve for validation.
    RAISES (ConfigValidationError): If resolving the block fails validation.
    """
    try:
        registry.resolve(config[block], validate=True, schema=schema)
    except ConfigValidationError as original:
        # Rebuild the error so its title says which block failed; `from None`
        # drops the chained context of the original exception.
        raise ConfigValidationError.from_error(
            original,
            title=f"Config validation error for [{block}]",
            desc="For more information run: python -m spacy debug config config.cfg",
        ) from None