From 94b3d5cb37ab3481a4fa7309757728c223f7c1e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Thu, 20 Apr 2023 16:48:37 +0200 Subject: [PATCH] Move `set_{gpu_allocator,seed}_from_config` to spacy.util --- spacy/training/initialize.py | 27 ++++++--------------------- spacy/training/loop.py | 17 ++++++----------- spacy/util.py | 20 ++++++++++++++++++++ 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/spacy/training/initialize.py b/spacy/training/initialize.py index 47706918b..479f5c9a7 100644 --- a/spacy/training/initialize.py +++ b/spacy/training/initialize.py @@ -1,6 +1,5 @@ from typing import Union, Dict, Optional, Any, IO, TYPE_CHECKING -from thinc.api import Config, fix_random_seed, set_gpu_allocator -from thinc.api import ConfigValidationError +from thinc.api import Config, ConfigValidationError from pathlib import Path import srsly import numpy @@ -19,6 +18,7 @@ from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining from ..util import registry, load_model_from_config, resolve_dot_names, logger from ..util import load_model, ensure_path, get_sourced_components from ..util import OOV_RANK, DEFAULT_OOV_PROB +from ..util import set_gpu_allocator_from_config, set_seed_from_config if TYPE_CHECKING: from ..language import Language # noqa: F401 @@ -27,8 +27,8 @@ if TYPE_CHECKING: def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language": raw_config = config config = raw_config.interpolate() - _set_seed_from_config(config) - _set_gpu_allocator_from_config(config, use_gpu) + set_seed_from_config(config) + set_gpu_allocator_from_config(config, use_gpu) # Use original config here before it's resolved to functions sourced = get_sourced_components(config) nlp = load_model_from_config(raw_config, auto_fill=True) @@ -106,8 +106,8 @@ def init_nlp_student( """ raw_config = config config = raw_config.interpolate() - _set_seed_from_config(config) - _set_gpu_allocator_from_config(config, use_gpu) + 
set_seed_from_config(config) + set_gpu_allocator_from_config(config, use_gpu) # Use original config here before it's resolved to functions sourced = get_sourced_components(config) @@ -422,18 +422,3 @@ def ensure_shape(vectors_loc): yield from lines2 lines2.close() lines.close() - - -def _set_gpu_allocator_from_config(config: Config, use_gpu: int): - if "gpu_allocator" not in config["training"]: - raise ValueError(Errors.E1015.format(value="[training] gpu_allocator")) - allocator = config["training"]["gpu_allocator"] - if use_gpu >= 0 and allocator: - set_gpu_allocator(allocator) - - -def _set_seed_from_config(config: Config): - if "seed" not in config["training"]: - raise ValueError(Errors.E1015.format(value="[training] seed")) - if config["training"]["seed"] is not None: - fix_random_seed(config["training"]["seed"]) diff --git a/spacy/training/loop.py b/spacy/training/loop.py index 953fffaed..964ec9230 100644 --- a/spacy/training/loop.py +++ b/spacy/training/loop.py @@ -2,7 +2,7 @@ from typing import List, Callable, Tuple, Dict, Iterable, Union, Any, IO from typing import Optional, TYPE_CHECKING from pathlib import Path from timeit import default_timer as timer -from thinc.api import Optimizer, Config, constant, fix_random_seed, set_gpu_allocator +from thinc.api import Optimizer, Config, constant from wasabi import Printer import random import sys @@ -15,6 +15,7 @@ from ..errors import Errors from ..tokens.doc import Doc from .. import ty from ..util import resolve_dot_names, registry, logger +from ..util import set_gpu_allocator_from_config, set_seed_from_config if TYPE_CHECKING: from ..language import Language # noqa: F401 @@ -53,11 +54,8 @@ def distill( msg = Printer(no_print=True) # Create iterator, which yields out info after each optimization step. 
config = student.config.interpolate() - if config["training"]["seed"] is not None: - fix_random_seed(config["training"]["seed"]) - allocator = config["training"]["gpu_allocator"] - if use_gpu >= 0 and allocator: - set_gpu_allocator(allocator) + set_seed_from_config(config) + set_gpu_allocator_from_config(config, use_gpu) T = registry.resolve(config["training"], schema=ConfigSchemaTraining) D = registry.resolve(config["distillation"], schema=ConfigSchemaDistill) dot_names = [D["corpus"], T["dev_corpus"]] @@ -175,11 +173,8 @@ def train( msg = Printer(no_print=True) # Create iterator, which yields out info after each optimization step. config = nlp.config.interpolate() - if config["training"]["seed"] is not None: - fix_random_seed(config["training"]["seed"]) - allocator = config["training"]["gpu_allocator"] - if use_gpu >= 0 and allocator: - set_gpu_allocator(allocator) + set_seed_from_config(config) + set_gpu_allocator_from_config(config, use_gpu) T = registry.resolve(config["training"], schema=ConfigSchemaTraining) dot_names = [T["train_corpus"], T["dev_corpus"]] train_corpus, dev_corpus = resolve_dot_names(config, dot_names) diff --git a/spacy/util.py b/spacy/util.py index 1ce869152..f0d5b3f58 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -11,6 +11,7 @@ from pathlib import Path import thinc from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer from thinc.api import ConfigValidationError, Model, constant as constant_schedule +from thinc.api import fix_random_seed, set_gpu_allocator import functools import itertools import numpy @@ -1790,3 +1791,22 @@ def find_available_port(start: int, host: str, auto_select: bool = False) -> int # if we get here, the port changed warnings.warn(Warnings.W124.format(host=host, port=start, serve_port=port)) return port + + +def set_gpu_allocator_from_config(config: Config, use_gpu: int): + """Change the global GPU allocator based on the value in + the configuration.""" + if "gpu_allocator" not in 
config["training"]: + raise ValueError(Errors.E1015.format(value="[training] gpu_allocator")) + allocator = config["training"]["gpu_allocator"] + if use_gpu >= 0 and allocator: + set_gpu_allocator(allocator) + + +def set_seed_from_config(config: Config): + """Set the random number generator seed to the value in + the configuration.""" + if "seed" not in config["training"]: + raise ValueError(Errors.E1015.format(value="[training] seed")) + if config["training"]["seed"] is not None: + fix_random_seed(config["training"]["seed"])