Introducing the gpu_allocator (#6091)

* rename 'use_pytorch_for_gpu_memory' to 'gpu_allocator'

* --code instead of --code-path

* update documentation

* avoid querying the "system" section directly

* add explanation of gpu_allocator to TF/PyTorch section in docs

* fix typo

* fix typo 2

* use set_gpu_allocator from thinc 8.0.0a34

* default null instead of empty string
This commit is contained in:
Sofie Van Landeghem 2020-09-19 01:17:02 +02:00 committed by GitHub
parent 0406200a1e
commit 39872de1f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 54 additions and 31 deletions

View File

@ -6,7 +6,7 @@ requires = [
"cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0",
"thinc>=8.0.0a33,<8.0.0a40",
"thinc>=8.0.0a34,<8.0.0a40",
"blis>=0.4.0,<0.5.0",
"pytokenizations",
"pathy"

View File

@ -1,7 +1,7 @@
# Our libraries
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=8.0.0a33,<8.0.0a40
thinc>=8.0.0a34,<8.0.0a40
blis>=0.4.0,<0.5.0
ml_datasets==0.2.0a0
murmurhash>=0.28.0,<1.1.0

View File

@ -34,13 +34,13 @@ setup_requires =
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
murmurhash>=0.28.0,<1.1.0
thinc>=8.0.0a33,<8.0.0a40
thinc>=8.0.0a34,<8.0.0a40
install_requires =
# Our libraries
murmurhash>=0.28.0,<1.1.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=8.0.0a33,<8.0.0a40
thinc>=8.0.0a34,<8.0.0a40
blis>=0.4.0,<0.5.0
wasabi>=0.8.0,<1.1.0
srsly>=2.1.0,<3.0.0

View File

@ -2,7 +2,7 @@ from typing import Dict, Any, Optional
from pathlib import Path
from wasabi import msg
from thinc.api import require_gpu, fix_random_seed, set_dropout_rate, Adam
from thinc.api import Model, data_validation
from thinc.api import Model, data_validation, set_gpu_allocator
import typer
from ._util import Arg, Opt, debug_cli, show_validation_error
@ -53,7 +53,12 @@ def debug_model_cli(
}
config_overrides = parse_config_overrides(ctx.args)
with show_validation_error(config_path):
config = util.load_config(config_path, overrides=config_overrides)
config = util.load_config(
config_path, overrides=config_overrides, interpolate=True
)
allocator = config["training"]["gpu_allocator"]
if use_gpu >= 0 and allocator:
set_gpu_allocator(allocator)
nlp, config = util.load_model_from_config(config_path)
seed = config["training"]["seed"]
if seed is not None:

View File

@ -4,10 +4,9 @@ import time
import re
from collections import Counter
from pathlib import Path
from thinc.api import Config
from thinc.api import use_pytorch_for_gpu_memory, require_gpu
from thinc.api import require_gpu, set_gpu_allocator
from thinc.api import set_dropout_rate, to_categorical, fix_random_seed
from thinc.api import CosineDistance, L2Distance
from thinc.api import Config, CosineDistance, L2Distance
from wasabi import msg
import srsly
from functools import partial
@ -32,7 +31,7 @@ def pretrain_cli(
ctx: typer.Context, # This is only used to read additional arguments
config_path: Path = Arg(..., help="Path to config file", exists=True, dir_okay=False),
output_dir: Path = Arg(..., help="Directory to write weights to on each epoch"),
code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
resume_path: Optional[Path] = Opt(None, "--resume-path", "-r", help="Path to pretrained weights from which to resume pretraining"),
epoch_resume: Optional[int] = Opt(None, "--epoch-resume", "-er", help="The epoch to resume counting from when using --resume-path. Prevents unintended overwriting of existing weight files."),
use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
@ -99,10 +98,12 @@ def pretrain(
epoch_resume: Optional[int] = None,
use_gpu: int = -1,
):
if config["system"].get("seed") is not None:
fix_random_seed(config["system"]["seed"])
if use_gpu >= 0 and config["system"].get("use_pytorch_for_gpu_memory"):
use_pytorch_for_gpu_memory()
if config["training"]["seed"] is not None:
fix_random_seed(config["training"]["seed"])
allocator = config["training"]["gpu_allocator"]
if use_gpu >= 0 and allocator:
set_gpu_allocator(allocator)
nlp, config = util.load_model_from_config(config)
P_cfg = config["pretraining"]
corpus = dot_to_object(config, P_cfg["corpus"])

View File

@ -8,7 +8,7 @@ train = ""
dev = ""
[system]
use_pytorch_for_gpu_memory = {{ "true" if use_transformer else "false" }}
gpu_allocator = {{ "pytorch" if use_transformer else "" }}
[nlp]
lang = "{{ lang }}"

View File

@ -6,8 +6,7 @@ from pathlib import Path
from wasabi import msg
import thinc
import thinc.schedules
from thinc.api import use_pytorch_for_gpu_memory, require_gpu, fix_random_seed
from thinc.api import Config, Optimizer
from thinc.api import Config, Optimizer, require_gpu, fix_random_seed, set_gpu_allocator
import random
import typer
import logging
@ -29,7 +28,7 @@ def train_cli(
ctx: typer.Context, # This is only used to read additional arguments
config_path: Path = Arg(..., help="Path to config file", exists=True),
output_path: Optional[Path] = Opt(None, "--output", "--output-path", "-o", help="Output directory to store trained pipeline in"),
code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
resume: bool = Opt(False, "--resume", "-R", help="Resume training"),
@ -79,11 +78,11 @@ def train(
config = util.load_config(
config_path, overrides=config_overrides, interpolate=True
)
if config.get("training", {}).get("seed") is not None:
if config["training"]["seed"] is not None:
fix_random_seed(config["training"]["seed"])
if config.get("system", {}).get("use_pytorch_for_gpu_memory"):
# It feels kind of weird to not have a default for this.
use_pytorch_for_gpu_memory()
allocator = config["training"]["gpu_allocator"]
if use_gpu >= 0 and allocator:
set_gpu_allocator(allocator)
# Use original config here before it's resolved to functions
sourced_components = get_sourced_components(config)
with show_validation_error(config_path):

View File

@ -6,7 +6,7 @@ init_tok2vec = null
[system]
seed = 0
use_pytorch_for_gpu_memory = false
gpu_allocator = null
[nlp]
lang = null
@ -52,6 +52,7 @@ limit = 0
# Training hyper-parameters and additional features.
[training]
seed = ${system.seed}
gpu_allocator = ${system.gpu_allocator}
dropout = 0.1
accumulate_gradient = 1
# Extra resources for transfer-learning or pseudo-rehearsal
@ -75,7 +76,6 @@ train_corpus = "corpora.train"
[training.logger]
@loggers = "spacy.ConsoleLogger.v1"
[training.batcher]
@batchers = "spacy.batch_by_words.v1"
discard_oversize = false

View File

@ -207,6 +207,7 @@ class ConfigSchemaTraining(BaseModel):
max_steps: StrictInt = Field(..., title="Maximum number of update steps to train for")
eval_frequency: StrictInt = Field(..., title="How often to evaluate during training (steps)")
seed: Optional[StrictInt] = Field(..., title="Random seed")
gpu_allocator: Optional[StrictStr] = Field(..., title="Memory allocator when running on GPU")
accumulate_gradient: StrictInt = Field(..., title="Whether to divide the batch up into substeps")
score_weights: Dict[StrictStr, Union[StrictFloat, StrictInt]] = Field(..., title="Scores to report and their weights for selecting final model")
init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights")

View File

@ -763,6 +763,7 @@ $ python -m spacy train [config_path] [--output] [--code] [--verbose] [overrides
| `--output`, `-o` | Directory to store trained pipeline in. Will be created if it doesn't exist. ~~Optional[Path] \(positional)~~ |
| `--code`, `-c` | Path to Python file with additional code to be imported. Allows [registering custom functions](/usage/training#custom-functions) for new architectures. ~~Optional[Path] \(option)~~ |
| `--verbose`, `-V` | Show more detailed messages during training. ~~bool (flag)~~ |
| `--gpu-id`, `-g` | GPU ID or `-1` for CPU. Defaults to `-1`. ~~int (option)~~ |
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
| overrides | Config parameters to override. Should be options starting with `--` that correspond to the config section and value to override, e.g. `--paths.train ./train.spacy`. ~~Any (option/flag)~~ |
| **CREATES** | The final trained pipeline and the best trained pipeline. |
@ -798,11 +799,12 @@ $ python -m spacy pretrain [config_path] [output_dir] [--code] [--resume-path] [
| Name | Description |
| ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `output_dir` | Directory to save binary weights to on each epoch. ~~Path (positional)~~ |
| `config_path` | Path to [training config](/api/data-formats#config) file containing all settings and hyperparameters. ~~Path (positional)~~ |
| `output_dir` | Directory to save binary weights to on each epoch. ~~Path (positional)~~ |
| `--code`, `-c` | Path to Python file with additional code to be imported. Allows [registering custom functions](/usage/training#custom-functions) for new architectures. ~~Optional[Path] \(option)~~ |
| `--resume-path`, `-r` | Path to pretrained weights from which to resume pretraining. ~~Optional[Path] \(option)~~ |
| `--epoch-resume`, `-er` | The epoch to resume counting from when using `--resume-path`. Prevents unintended overwriting of existing weight files. ~~Optional[int] \(option)~~ |
| `--gpu-id`, `-g` | GPU ID or `-1` for CPU. Defaults to `-1`. ~~int (option)~~ |
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
| overrides | Config parameters to override. Should be options starting with `--` that correspond to the config section and value to override, e.g. `--training.dropout 0.2`. ~~Any (option/flag)~~ |
| **CREATES** | The pretrained weights that can be used to initialize `spacy train`. |

View File

@ -189,6 +189,7 @@ process that are used when you run [`spacy train`](/api/cli#train).
| `dev_corpus` | Dot notation of the config location defining the dev corpus. Defaults to `corpora.dev`. ~~str~~ |
| `dropout` | The dropout rate. Defaults to `0.1`. ~~float~~ |
| `eval_frequency` | How often to evaluate during training (steps). Defaults to `200`. ~~int~~ |
| `gpu_allocator` | Library for cupy to route GPU memory allocation to. Can be "pytorch" or "tensorflow". Defaults to variable `${system.gpu_allocator}`. ~~str~~ |
| `frozen_components` | Pipeline component names that are "frozen" and shouldn't be updated during training. See [here](/usage/training#config-components) for details. Defaults to `[]`. ~~List[str]~~ |
| `init_tok2vec` | Optional path to pretrained tok2vec weights created with [`spacy pretrain`](/api/cli#pretrain). Defaults to variable `${paths.init_tok2vec}`. ~~Optional[str]~~ |
| `max_epochs` | Maximum number of epochs to train for. Defaults to `0`. ~~int~~ |

View File

@ -145,9 +145,10 @@ pipelines.
> nlp = spacy.load("en_core_web_sm")
> ```
| Name | Description |
| ----------- | --------------------------------------- |
| **RETURNS** | Whether the GPU was activated. ~~bool~~ |
| Name | Description |
| ----------- | ------------------------------------------------ |
| `gpu_id` | Device index to select. Defaults to `0`. ~~int~~ |
| **RETURNS** | Whether the GPU was activated. ~~bool~~ |
### spacy.require_gpu {#spacy.require_gpu tag="function" new="2.0.14"}
@ -164,9 +165,10 @@ and _before_ loading any pipelines.
> nlp = spacy.load("en_core_web_sm")
> ```
| Name | Description |
| ----------- | --------------- |
| **RETURNS** | `True` ~~bool~~ |
| Name | Description |
| ----------- | ------------------------------------------------ |
| `gpu_id` | Device index to select. Defaults to `0`. ~~int~~ |
| **RETURNS** | `True` ~~bool~~ |
## displaCy {#displacy source="spacy/displacy"}

View File

@ -356,6 +356,18 @@ that training configs are complete and experiments fully reproducible.
</Infobox>
Note that when using a PyTorch or Tensorflow model, it is recommended to set the GPU
memory allocator accordingly. When `gpu_allocator` is set to "pytorch" or
"tensorflow" in the training config, cupy will allocate memory via those respective libraries,
preventing OOM errors when there's available memory sitting in the other
library's pool.
```ini
### config.cfg (excerpt)
[training]
gpu_allocator = "pytorch"
```
## Custom models with Thinc {#thinc}
Of course it's also possible to define the `Model` from the previous section