Validate seed and gpu_allocator manually

This commit is contained in:
Adriane Boyd 2021-01-14 16:57:57 +01:00
parent 5fb8b7037a
commit 681a6195f7
2 changed files with 7 additions and 15 deletions

View File

@ -730,6 +730,8 @@ class Errors:
"DocBin (.spacy) format. If your data is in spaCy v2's JSON " "DocBin (.spacy) format. If your data is in spaCy v2's JSON "
"training format, convert it using `python -m spacy convert " "training format, convert it using `python -m spacy convert "
"file.json .`.") "file.json .`.")
E1015 = ("Can't initialize model from config: no {value} found. For more "
"information, run: python -m spacy debug config config.cfg")
# Deprecated model shortcuts, only used in errors and warnings # Deprecated model shortcuts, only used in errors and warnings

View File

@ -1,5 +1,4 @@
from typing import Union, Dict, Optional, Any, List, IO, TYPE_CHECKING from typing import Union, Dict, Optional, Any, List, IO, TYPE_CHECKING
from pydantic import BaseModel
from thinc.api import Config, fix_random_seed, set_gpu_allocator from thinc.api import Config, fix_random_seed, set_gpu_allocator
from thinc.api import ConfigValidationError from thinc.api import ConfigValidationError
from pathlib import Path from pathlib import Path
@ -13,7 +12,7 @@ import tqdm
from ..lookups import Lookups from ..lookups import Lookups
from ..vectors import Vectors from ..vectors import Vectors
from ..errors import Errors from ..errors import Errors
from ..schemas import ConfigSchemaInit, ConfigSchemaTraining from ..schemas import ConfigSchemaTraining
from ..util import registry, load_model_from_config, resolve_dot_names, logger from ..util import registry, load_model_from_config, resolve_dot_names, logger
from ..util import load_model, ensure_path, OOV_RANK, DEFAULT_OOV_PROB from ..util import load_model, ensure_path, OOV_RANK, DEFAULT_OOV_PROB
@ -24,9 +23,10 @@ if TYPE_CHECKING:
def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language": def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
raw_config = config raw_config = config
config = raw_config.interpolate() config = raw_config.interpolate()
# Validate config before accessing values if "seed" not in config["training"]:
_validate_config_block(config, "initialize", ConfigSchemaInit) raise ValueError(Errors.E1015.format(value="[training] seed"))
_validate_config_block(config, "training", ConfigSchemaTraining) if "gpu_allocator" not in config["training"]:
raise ValueError(Errors.E1015.format(value="[training] gpu_allocator"))
if config["training"]["seed"] is not None: if config["training"]["seed"] is not None:
fix_random_seed(config["training"]["seed"]) fix_random_seed(config["training"]["seed"])
allocator = config["training"]["gpu_allocator"] allocator = config["training"]["gpu_allocator"]
@ -273,13 +273,3 @@ def ensure_shape(lines):
length = len(captured) length = len(captured)
yield f"{length} {width}" yield f"{length} {width}"
yield from captured yield from captured
def _validate_config_block(config: Config, block: str, schema: BaseModel):
    """Validate a single top-level block of the config against a schema.

    The block is passed through registry.resolve with validation switched
    on. The resolved result is discarded; the call is made only so that
    an invalid block raises.

    config (Config): The config containing the block.
    block (str): Name of the block to look up in the config.
    schema (BaseModel): The schema handed to registry.resolve.
    RAISES (ConfigValidationError): If resolving fails. The error is rebuilt
        with a title naming the block and a hint to run the debug command,
        and the original exception context is suppressed.
    """
    try:
        registry.resolve(config[block], validate=True, schema=schema)
    except ConfigValidationError as validation_error:
        # "from None" drops the chained traceback so only the rebuilt error
        # is reported.
        raise ConfigValidationError.from_error(
            validation_error,
            title=f"Config validation error for [{block}]",
            desc="For more information run: python -m spacy debug config config.cfg",
        ) from None