mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-12 01:02:23 +03:00
Move set_{gpu_allocator,seed}_from_config
to spacy.util
This commit is contained in:
parent
c355367f87
commit
94b3d5cb37
|
@ -1,6 +1,5 @@
|
|||
from typing import Union, Dict, Optional, Any, IO, TYPE_CHECKING
|
||||
from thinc.api import Config, fix_random_seed, set_gpu_allocator
|
||||
from thinc.api import ConfigValidationError
|
||||
from thinc.api import Config, ConfigValidationError
|
||||
from pathlib import Path
|
||||
import srsly
|
||||
import numpy
|
||||
|
@ -19,6 +18,7 @@ from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
|
|||
from ..util import registry, load_model_from_config, resolve_dot_names, logger
|
||||
from ..util import load_model, ensure_path, get_sourced_components
|
||||
from ..util import OOV_RANK, DEFAULT_OOV_PROB
|
||||
from ..util import set_gpu_allocator_from_config, set_seed_from_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..language import Language # noqa: F401
|
||||
|
@ -27,8 +27,8 @@ if TYPE_CHECKING:
|
|||
def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
||||
raw_config = config
|
||||
config = raw_config.interpolate()
|
||||
_set_seed_from_config(config)
|
||||
_set_gpu_allocator_from_config(config, use_gpu)
|
||||
set_seed_from_config(config)
|
||||
set_gpu_allocator_from_config(config, use_gpu)
|
||||
# Use original config here before it's resolved to functions
|
||||
sourced = get_sourced_components(config)
|
||||
nlp = load_model_from_config(raw_config, auto_fill=True)
|
||||
|
@ -106,8 +106,8 @@ def init_nlp_student(
|
|||
"""
|
||||
raw_config = config
|
||||
config = raw_config.interpolate()
|
||||
_set_seed_from_config(config)
|
||||
_set_gpu_allocator_from_config(config, use_gpu)
|
||||
set_seed_from_config(config)
|
||||
set_gpu_allocator_from_config(config, use_gpu)
|
||||
|
||||
# Use original config here before it's resolved to functions
|
||||
sourced = get_sourced_components(config)
|
||||
|
@ -422,18 +422,3 @@ def ensure_shape(vectors_loc):
|
|||
yield from lines2
|
||||
lines2.close()
|
||||
lines.close()
|
||||
|
||||
|
||||
def _set_gpu_allocator_from_config(config: Config, use_gpu: int):
|
||||
if "gpu_allocator" not in config["training"]:
|
||||
raise ValueError(Errors.E1015.format(value="[training] gpu_allocator"))
|
||||
allocator = config["training"]["gpu_allocator"]
|
||||
if use_gpu >= 0 and allocator:
|
||||
set_gpu_allocator(allocator)
|
||||
|
||||
|
||||
def _set_seed_from_config(config: Config):
|
||||
if "seed" not in config["training"]:
|
||||
raise ValueError(Errors.E1015.format(value="[training] seed"))
|
||||
if config["training"]["seed"] is not None:
|
||||
fix_random_seed(config["training"]["seed"])
|
||||
|
|
|
@ -2,7 +2,7 @@ from typing import List, Callable, Tuple, Dict, Iterable, Union, Any, IO
|
|||
from typing import Optional, TYPE_CHECKING
|
||||
from pathlib import Path
|
||||
from timeit import default_timer as timer
|
||||
from thinc.api import Optimizer, Config, constant, fix_random_seed, set_gpu_allocator
|
||||
from thinc.api import Optimizer, Config, constant
|
||||
from wasabi import Printer
|
||||
import random
|
||||
import sys
|
||||
|
@ -15,6 +15,7 @@ from ..errors import Errors
|
|||
from ..tokens.doc import Doc
|
||||
from .. import ty
|
||||
from ..util import resolve_dot_names, registry, logger
|
||||
from ..util import set_gpu_allocator_from_config, set_seed_from_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..language import Language # noqa: F401
|
||||
|
@ -53,11 +54,8 @@ def distill(
|
|||
msg = Printer(no_print=True)
|
||||
# Create iterator, which yields out info after each optimization step.
|
||||
config = student.config.interpolate()
|
||||
if config["training"]["seed"] is not None:
|
||||
fix_random_seed(config["training"]["seed"])
|
||||
allocator = config["training"]["gpu_allocator"]
|
||||
if use_gpu >= 0 and allocator:
|
||||
set_gpu_allocator(allocator)
|
||||
set_seed_from_config(config)
|
||||
set_gpu_allocator_from_config(config, use_gpu)
|
||||
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
||||
D = registry.resolve(config["distillation"], schema=ConfigSchemaDistill)
|
||||
dot_names = [D["corpus"], T["dev_corpus"]]
|
||||
|
@ -175,11 +173,8 @@ def train(
|
|||
msg = Printer(no_print=True)
|
||||
# Create iterator, which yields out info after each optimization step.
|
||||
config = nlp.config.interpolate()
|
||||
if config["training"]["seed"] is not None:
|
||||
fix_random_seed(config["training"]["seed"])
|
||||
allocator = config["training"]["gpu_allocator"]
|
||||
if use_gpu >= 0 and allocator:
|
||||
set_gpu_allocator(allocator)
|
||||
set_seed_from_config(config)
|
||||
set_gpu_allocator_from_config(config, use_gpu)
|
||||
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
||||
dot_names = [T["train_corpus"], T["dev_corpus"]]
|
||||
train_corpus, dev_corpus = resolve_dot_names(config, dot_names)
|
||||
|
|
|
@ -11,6 +11,7 @@ from pathlib import Path
|
|||
import thinc
|
||||
from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
|
||||
from thinc.api import ConfigValidationError, Model, constant as constant_schedule
|
||||
from thinc.api import fix_random_seed, set_gpu_allocator
|
||||
import functools
|
||||
import itertools
|
||||
import numpy
|
||||
|
@ -1790,3 +1791,22 @@ def find_available_port(start: int, host: str, auto_select: bool = False) -> int
|
|||
# if we get here, the port changed
|
||||
warnings.warn(Warnings.W124.format(host=host, port=start, serve_port=port))
|
||||
return port
|
||||
|
||||
|
||||
def set_gpu_allocator_from_config(config: Config, use_gpu: int):
|
||||
"""Change the global GPU allocator based to the value in
|
||||
the configuration."""
|
||||
if "gpu_allocator" not in config["training"]:
|
||||
raise ValueError(Errors.E1015.format(value="[training] gpu_allocator"))
|
||||
allocator = config["training"]["gpu_allocator"]
|
||||
if use_gpu >= 0 and allocator:
|
||||
set_gpu_allocator(allocator)
|
||||
|
||||
|
||||
def set_seed_from_config(config: Config):
|
||||
"""Set the random number generator seed to the value in
|
||||
the configuration."""
|
||||
if "seed" not in config["training"]:
|
||||
raise ValueError(Errors.E1015.format(value="[training] seed"))
|
||||
if config["training"]["seed"] is not None:
|
||||
fix_random_seed(config["training"]["seed"])
|
||||
|
|
Loading…
Reference in New Issue
Block a user