mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Validate seed and gpu_allocator manually
This commit is contained in:
parent
5fb8b7037a
commit
681a6195f7
|
@ -730,6 +730,8 @@ class Errors:
|
|||
"DocBin (.spacy) format. If your data is in spaCy v2's JSON "
|
||||
"training format, convert it using `python -m spacy convert "
|
||||
"file.json .`.")
|
||||
E1015 = ("Can't initialize model from config: no {value} found. For more "
|
||||
"information, run: python -m spacy debug config config.cfg")
|
||||
|
||||
|
||||
# Deprecated model shortcuts, only used in errors and warnings
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from typing import Union, Dict, Optional, Any, List, IO, TYPE_CHECKING
|
||||
from pydantic import BaseModel
|
||||
from thinc.api import Config, fix_random_seed, set_gpu_allocator
|
||||
from thinc.api import ConfigValidationError
|
||||
from pathlib import Path
|
||||
|
@ -13,7 +12,7 @@ import tqdm
|
|||
from ..lookups import Lookups
|
||||
from ..vectors import Vectors
|
||||
from ..errors import Errors
|
||||
from ..schemas import ConfigSchemaInit, ConfigSchemaTraining
|
||||
from ..schemas import ConfigSchemaTraining
|
||||
from ..util import registry, load_model_from_config, resolve_dot_names, logger
|
||||
from ..util import load_model, ensure_path, OOV_RANK, DEFAULT_OOV_PROB
|
||||
|
||||
|
@ -24,9 +23,10 @@ if TYPE_CHECKING:
|
|||
def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
||||
raw_config = config
|
||||
config = raw_config.interpolate()
|
||||
# Validate config before accessing values
|
||||
_validate_config_block(config, "initialize", ConfigSchemaInit)
|
||||
_validate_config_block(config, "training", ConfigSchemaTraining)
|
||||
if "seed" not in config["training"]:
|
||||
raise ValueError(Errors.E1015.format(value="[training] seed"))
|
||||
if "gpu_allocator" not in config["training"]:
|
||||
raise ValueError(Errors.E1015.format(value="[training] gpu_allocator"))
|
||||
if config["training"]["seed"] is not None:
|
||||
fix_random_seed(config["training"]["seed"])
|
||||
allocator = config["training"]["gpu_allocator"]
|
||||
|
@ -273,13 +273,3 @@ def ensure_shape(lines):
|
|||
length = len(captured)
|
||||
yield f"{length} {width}"
|
||||
yield from captured
|
||||
|
||||
|
||||
def _validate_config_block(config: Config, block: str, schema: BaseModel):
|
||||
try:
|
||||
registry.resolve(config[block], validate=True, schema=schema)
|
||||
except ConfigValidationError as e:
|
||||
title = f"Config validation error for [{block}]"
|
||||
desc = "For more information run: python -m spacy debug config config.cfg"
|
||||
err = ConfigValidationError.from_error(e, title=title, desc=desc)
|
||||
raise err from None
|
||||
|
|
Loading…
Reference in New Issue
Block a user