Validate seed and gpu_allocator manually

This commit is contained in:
Adriane Boyd 2021-01-14 16:57:57 +01:00
parent 5fb8b7037a
commit 681a6195f7
2 changed files with 7 additions and 15 deletions

View File

@ -730,6 +730,8 @@ class Errors:
"DocBin (.spacy) format. If your data is in spaCy v2's JSON "
"training format, convert it using `python -m spacy convert "
"file.json .`.")
E1015 = ("Can't initialize model from config: no {value} found. For more "
"information, run: python -m spacy debug config config.cfg")
# Deprecated model shortcuts, only used in errors and warnings

View File

@ -1,5 +1,4 @@
from typing import Union, Dict, Optional, Any, List, IO, TYPE_CHECKING
from pydantic import BaseModel
from thinc.api import Config, fix_random_seed, set_gpu_allocator
from thinc.api import ConfigValidationError
from pathlib import Path
@ -13,7 +12,7 @@ import tqdm
from ..lookups import Lookups
from ..vectors import Vectors
from ..errors import Errors
from ..schemas import ConfigSchemaInit, ConfigSchemaTraining
from ..schemas import ConfigSchemaTraining
from ..util import registry, load_model_from_config, resolve_dot_names, logger
from ..util import load_model, ensure_path, OOV_RANK, DEFAULT_OOV_PROB
@ -24,9 +23,10 @@ if TYPE_CHECKING:
def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
raw_config = config
config = raw_config.interpolate()
# Validate config before accessing values
_validate_config_block(config, "initialize", ConfigSchemaInit)
_validate_config_block(config, "training", ConfigSchemaTraining)
if "seed" not in config["training"]:
raise ValueError(Errors.E1015.format(value="[training] seed"))
if "gpu_allocator" not in config["training"]:
raise ValueError(Errors.E1015.format(value="[training] gpu_allocator"))
if config["training"]["seed"] is not None:
fix_random_seed(config["training"]["seed"])
allocator = config["training"]["gpu_allocator"]
@ -273,13 +273,3 @@ def ensure_shape(lines):
length = len(captured)
yield f"{length} {width}"
yield from captured
def _validate_config_block(config: Config, block: str, schema: BaseModel):
    """Validate a single top-level block of the config against a schema.

    config: The config whose block should be checked.
    block: Name of the top-level section to resolve, used to index `config`
        and echoed in the error title.
    schema: The schema passed through to `registry.resolve`.

    Raises a ConfigValidationError carrying a block-specific title and a hint
    to run the debug command when resolving the block fails validation.
    """
    try:
        registry.resolve(config[block], validate=True, schema=schema)
    except ConfigValidationError as exc:
        # Re-wrap the error with a friendlier title/description and drop the
        # original exception context so only the rewritten error is shown.
        raise ConfigValidationError.from_error(
            exc,
            title=f"Config validation error for [{block}]",
            desc="For more information run: python -m spacy debug config config.cfg",
        ) from None