Add distillation loop (#12542)

* Add distillation initialization and loop

* Fix up configuration keys

* Add docstring

* Type annotations

* init_nlp_distill -> init_nlp_student

* Do not resolve dot name distill corpus in initialization

(Since we don't use it.)

* student: do not request use of optimizer in student pipe

We apply finish up the updates once in the training loop instead.

Also add the necessary logic to `Language.distill` to mirror
`Language.update`.

* Correctly determine sort key in subdivide_batch

* Fix _distill_loop docstring wrt. stopping condition

* _distill_loop: fix distill_data docstring

Make similar changes in train_while_improving, since it also had
incorrect types and missing type annotations.

* Move `set_{gpu_allocator,seed}_from_config` to spacy.util

* Update Language.update docs for the sgd argument

* Type annotation

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

---------

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
This commit is contained in:
Daniël de Kok 2023-04-21 13:49:40 +02:00 committed by GitHub
parent 5d0f48fe69
commit 8a5814bf2c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 547 additions and 48 deletions

View File

@ -1024,7 +1024,7 @@ class Language:
examples: Iterable[Example],
*,
drop: float = 0.0,
sgd: Optional[Optimizer] = None,
sgd: Union[Optimizer, None, Literal[False]] = None,
losses: Optional[Dict[str, float]] = None,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
exclude: Iterable[str] = SimpleFrozenList(),
@ -1037,7 +1037,9 @@ class Language:
(teacher) and predicted (student) docs must have the same number of
tokens and the same orthography.
drop (float): The dropout rate.
sgd (Optional[Optimizer]): An optimizer.
sgd (Union[Optimizer, None, Literal[False]]): An optimizer. Will
be created via create_optimizer if 'None'. No optimizer will
be used when set to 'False'.
losses (Optional(Dict[str, float])): Dictionary to update with the loss,
keyed by component.
component_cfg (Optional[Dict[str, Dict[str, Any]]]): Config parameters
@ -1107,11 +1109,23 @@ class Language:
student_proc.distill(
teacher_pipe,
examples,
sgd=sgd,
sgd=None,
losses=losses,
**component_cfg[student_name],
)
# Only finish the update after all component updates are done. Some
# components may share weights (such as tok2vec) and we only want
# to apply weight updates after all gradients are accumulated.
for student_name, student_proc in self.pipeline:
if (
student_name not in exclude
and isinstance(student_proc, ty.DistillableComponent)
and student_proc.is_distillable
and sgd not in (None, False)
):
student_proc.finish_update(sgd)
return losses
def disable_pipes(self, *names) -> "DisabledPipes":
@ -1837,7 +1851,7 @@ class Language:
# using the nlp.config with all defaults.
config = util.copy_config(config)
orig_pipeline = config.pop("components", {})
orig_distill = config.pop("distill", None)
orig_distill = config.pop("distillation", None)
orig_pretraining = config.pop("pretraining", None)
config["components"] = {}
if auto_fill:
@ -1847,8 +1861,8 @@ class Language:
filled["components"] = orig_pipeline
config["components"] = orig_pipeline
if orig_distill is not None:
filled["distill"] = orig_distill
config["distill"] = orig_distill
filled["distillation"] = orig_distill
config["distillation"] = orig_distill
if orig_pretraining is not None:
filled["pretraining"] = orig_pretraining
config["pretraining"] = orig_pretraining

View File

@ -462,7 +462,7 @@ CONFIG_SCHEMAS = {
"training": ConfigSchemaTraining,
"pretraining": ConfigSchemaPretrain,
"initialize": ConfigSchemaInit,
"distill": ConfigSchemaDistill,
"distillation": ConfigSchemaDistill,
}

View File

@ -0,0 +1,111 @@
from typing import Callable, Iterable, Iterator
import pytest
from spacy import Language
from spacy.training import Example
from spacy.training.initialize import init_nlp_student
from spacy.training.loop import distill, train
from spacy.util import load_model_from_config, registry
from thinc.api import Config
@pytest.fixture
def config_str():
return """
[nlp]
lang = "en"
pipeline = ["senter"]
disabled = []
before_creation = null
after_creation = null
after_pipeline_creation = null
batch_size = 1000
tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"}
[components]
[components.senter]
factory = "senter"
[training]
dev_corpus = "corpora.dev"
train_corpus = "corpora.train"
max_steps = 50
seed = 1
gpu_allocator = null
[distillation]
corpus = "corpora.train"
dropout = 0.1
max_epochs = 0
max_steps = 50
student_to_teacher = {}
[distillation.batcher]
@batchers = "spacy.batch_by_words.v1"
size = 3000
discard_oversize = false
tolerance = 0.2
[distillation.optimizer]
@optimizers = "Adam.v1"
beta1 = 0.9
beta2 = 0.999
L2_is_weight_decay = true
L2 = 0.01
grad_clip = 1.0
use_averages = true
eps = 1e-8
learn_rate = 1e-4
[corpora]
[corpora.dev]
@readers = "sentence_corpus"
[corpora.train]
@readers = "sentence_corpus"
"""
SENT_STARTS = [0] * 14
SENT_STARTS[0] = 1
SENT_STARTS[5] = 1
SENT_STARTS[9] = 1
TRAIN_DATA = [
(
"I like green eggs. Eat blue ham. I like purple eggs.",
{"sent_starts": SENT_STARTS},
),
(
"She likes purple eggs. They hate ham. You like yellow eggs.",
{"sent_starts": SENT_STARTS},
),
]
@pytest.mark.slow
def test_distill_loop(config_str):
@registry.readers("sentence_corpus")
def create_sentence_corpus() -> Callable[[Language], Iterable[Example]]:
return SentenceCorpus()
class SentenceCorpus:
def __call__(self, nlp: Language) -> Iterator[Example]:
for t in TRAIN_DATA:
yield Example.from_dict(nlp.make_doc(t[0]), t[1])
orig_config = Config().from_str(config_str)
teacher = load_model_from_config(orig_config, auto_fill=True, validate=True)
teacher.initialize()
train(teacher)
orig_config = Config().from_str(config_str)
student = init_nlp_student(orig_config, teacher)
student.initialize()
distill(teacher, student)
doc = student(TRAIN_DATA[0][0])
assert doc.sents[0].text == "I like green eggs."
assert doc.sents[1].text == "Eat blue ham."
assert doc.sents[2].text == "I like purple eggs."

View File

@ -1,6 +1,5 @@
from typing import Union, Dict, Optional, Any, IO, TYPE_CHECKING
from thinc.api import Config, fix_random_seed, set_gpu_allocator
from thinc.api import ConfigValidationError
from thinc.api import Config, ConfigValidationError
from pathlib import Path
import srsly
import numpy
@ -15,10 +14,11 @@ from .pretrain import get_tok2vec_ref
from ..lookups import Lookups
from ..vectors import Vectors, Mode as VectorsMode
from ..errors import Errors, Warnings
from ..schemas import ConfigSchemaTraining
from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
from ..util import registry, load_model_from_config, resolve_dot_names, logger
from ..util import load_model, ensure_path, get_sourced_components
from ..util import OOV_RANK, DEFAULT_OOV_PROB
from ..util import set_gpu_allocator_from_config, set_seed_from_config
if TYPE_CHECKING:
from ..language import Language # noqa: F401
@ -27,15 +27,8 @@ if TYPE_CHECKING:
def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
raw_config = config
config = raw_config.interpolate()
if "seed" not in config["training"]:
raise ValueError(Errors.E1015.format(value="[training] seed"))
if "gpu_allocator" not in config["training"]:
raise ValueError(Errors.E1015.format(value="[training] gpu_allocator"))
if config["training"]["seed"] is not None:
fix_random_seed(config["training"]["seed"])
allocator = config["training"]["gpu_allocator"]
if use_gpu >= 0 and allocator:
set_gpu_allocator(allocator)
set_seed_from_config(config)
set_gpu_allocator_from_config(config, use_gpu)
# Use original config here before it's resolved to functions
sourced = get_sourced_components(config)
nlp = load_model_from_config(raw_config, auto_fill=True)
@ -101,6 +94,102 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
return nlp
def init_nlp_student(
config: Config, teacher: "Language", *, use_gpu: int = -1
) -> "Language":
"""Initialize student pipeline for distillation.
config (Config): Student model configuration.
teacher (Language): The teacher pipeline to distill from.
use_gpu (int): Whether to train on GPU. Make sure to call require_gpu
before calling this function.
"""
raw_config = config
config = raw_config.interpolate()
set_seed_from_config(config)
set_gpu_allocator_from_config(config, use_gpu)
# Use original config here before it's resolved to functions
sourced = get_sourced_components(config)
nlp = load_model_from_config(raw_config, auto_fill=True)
logger.info("Set up nlp object from config")
config = nlp.config.interpolate()
# Resolve all training-relevant sections using the filled nlp config
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
D = registry.resolve(config["distillation"], schema=ConfigSchemaDistill)
dot_names = [T["dev_corpus"]]
if not isinstance(D["corpus"], str):
raise ConfigValidationError(
desc=Errors.E897.format(field="distillation.corpus", type=type(D["corpus"]))
)
if not isinstance(T["dev_corpus"], str):
raise ConfigValidationError(
desc=Errors.E897.format(
field="training.dev_corpus", type=type(T["dev_corpus"])
)
)
(dev_corpus,) = resolve_dot_names(config, dot_names)
optimizer = T["optimizer"]
# Components that shouldn't be updated during training
frozen_components = T["frozen_components"]
# Sourced components that require resume_training
resume_components = [p for p in sourced if p not in frozen_components]
logger.info(f"Pipeline: {nlp.pipe_names}")
if resume_components:
with nlp.select_pipes(enable=resume_components):
logger.info(f"Resuming training for: {resume_components}")
nlp.resume_training(sgd=optimizer)
# Make sure that listeners are defined before initializing further
nlp._link_components()
# Get teacher labels to initialize student with.
student_to_teacher = D["student_to_teacher"]
teacher_pipes = dict(teacher.pipeline)
labels = {}
for name, pipe in nlp.pipeline:
# Copy teacher labels.
teacher_pipe_name = (
student_to_teacher[name] if name in student_to_teacher else name
)
teacher_pipe = teacher_pipes.get(teacher_pipe_name, None)
if (
teacher_pipe is not None
and getattr(teacher_pipe, "label_data", None) is not None
):
labels[name] = teacher_pipe.label_data # type: ignore[attr-defined]
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
# Initialize on the dev corpus, since the distillation corpus does
# usually not have labels. Since we copy the labels from the teacher
# pipe, the dev data does not have to be exhaustive.
if T["max_epochs"] == -1:
sample_size = 100
logger.debug(
f"Due to streamed train corpus, using only first {sample_size} "
f"examples for initialization. If necessary, provide all labels "
f"in [initialize]. More info: https://spacy.io/api/cli#init_labels"
)
nlp.initialize(lambda: islice(dev_corpus(nlp), sample_size), sgd=optimizer)
else:
nlp.initialize(lambda: dev_corpus(nlp), sgd=optimizer, labels=labels)
logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
# Detect components with listeners that are not frozen consistently
for name, proc in nlp.pipeline:
for listener in getattr(
proc, "listening_components", []
): # e.g. tok2vec/transformer
# Don't warn about components not in the pipeline
if listener not in nlp.pipe_names:
continue
if listener in frozen_components and name not in frozen_components:
logger.warning(Warnings.W087.format(name=name, listener=listener))
# We always check this regardless, in case user freezes tok2vec
if listener not in frozen_components and name in frozen_components:
if name not in T["annotating_components"]:
logger.warning(Warnings.W086.format(name=name, listener=listener))
return nlp
def init_vocab(
nlp: "Language",
*,

View File

@ -2,16 +2,20 @@ from typing import List, Callable, Tuple, Dict, Iterable, Union, Any, IO
from typing import Optional, TYPE_CHECKING
from pathlib import Path
from timeit import default_timer as timer
from thinc.api import Optimizer, Config, constant, fix_random_seed, set_gpu_allocator
from thinc.api import Optimizer, Config, constant
from wasabi import Printer
import random
import sys
import shutil
from .example import Example
from ..schemas import ConfigSchemaTraining
from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
from ..errors import Errors
from ..tokens.doc import Doc
from .. import ty
from ..util import resolve_dot_names, registry, logger
from ..util import set_gpu_allocator_from_config, set_seed_from_config
if TYPE_CHECKING:
from ..language import Language # noqa: F401
@ -21,6 +25,129 @@ DIR_MODEL_BEST = "model-best"
DIR_MODEL_LAST = "model-last"
def distill(
teacher: "Language",
student: "Language",
output_path: Optional[Path] = None,
*,
use_gpu: int = -1,
stdout: IO = sys.stdout,
stderr: IO = sys.stderr,
) -> Tuple["Language", Optional[Path]]:
"""Distill a student pipeline from a teacher pipeline.
teacher (Language): The teacher pipeline to distill from.
student (Language): The student pipeline to distill into.
output_path (Optional[Path]): Optional output path to save the student
model to.
use_gpu (int): Whether to train on GPU. Make sure to call require_gpu
before calling this function.
stdout (file): A file-like object to write output messages. To disable
printing, set to io.StringIO.
stderr (file): A second file-like object to write output messages. To disable
printing, set to io.StringIO.
RETURNS (tuple): The final student nlp object and the path to the exported
student model.
"""
# We use no_print here so we can respect the stdout/stderr options.
msg = Printer(no_print=True)
# Create iterator, which yields out info after each optimization step.
config = student.config.interpolate()
set_seed_from_config(config)
set_gpu_allocator_from_config(config, use_gpu)
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
D = registry.resolve(config["distillation"], schema=ConfigSchemaDistill)
dot_names = [D["corpus"], T["dev_corpus"]]
distill_corpus, dev_corpus = resolve_dot_names(config, dot_names)
optimizer = D["optimizer"]
score_weights = T["score_weights"]
batcher = D["batcher"]
train_logger = T["logger"]
before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
before_update = T["before_update"]
student_to_teacher = D["student_to_teacher"]
# Helper function to save checkpoints. This is a closure for convenience,
# to avoid passing in all the args all the time.
def save_checkpoint(is_best):
with student.use_params(optimizer.averages):
before_to_disk(student).to_disk(output_path / DIR_MODEL_LAST)
if is_best:
# Avoid saving twice (saving will be more expensive than
# the dir copy)
if (output_path / DIR_MODEL_BEST).exists():
shutil.rmtree(output_path / DIR_MODEL_BEST)
shutil.copytree(output_path / DIR_MODEL_LAST, output_path / DIR_MODEL_BEST)
# Components that shouldn't be updated during training
frozen_components = T["frozen_components"]
# Components that should set annotations on update
annotating_components = T["annotating_components"]
# Create iterator, which yields out info after each optimization step.
training_step_iterator = _distill_loop(
teacher,
student,
optimizer,
create_distill_batches(student, distill_corpus, batcher, D["max_epochs"]),
create_evaluation_callback(student, dev_corpus, score_weights),
dropout=D["dropout"],
accumulate_gradient=T["accumulate_gradient"],
max_steps=D["max_steps"],
eval_frequency=T["eval_frequency"],
exclude=frozen_components,
annotating_components=annotating_components,
before_update=before_update,
student_to_teacher=student_to_teacher,
)
clean_output_dir(output_path)
stdout.write(msg.info(f"Teacher pipeline: {teacher.pipe_names}") + "\n")
stdout.write(msg.info(f"Student pipeline: {student.pipe_names}") + "\n")
if frozen_components:
stdout.write(msg.info(f"Frozen components: {frozen_components}") + "\n")
if annotating_components:
stdout.write(
msg.info(f"Set annotations on update for: {annotating_components}") + "\n"
)
stdout.write(msg.info(f"Initial learn rate: {optimizer.learn_rate(step=0)}") + "\n")
with student.select_pipes(disable=frozen_components):
log_step, finalize_logger = train_logger(student, stdout, stderr)
try:
for batch, info, is_best_checkpoint in training_step_iterator:
if is_best_checkpoint is not None:
with student.select_pipes(disable=frozen_components):
update_meta(T, student, info)
if output_path is not None:
save_checkpoint(is_best_checkpoint)
info["output_path"] = str(output_path / DIR_MODEL_LAST)
log_step(info if is_best_checkpoint is not None else None)
except Exception as e:
if output_path is not None:
stdout.write(
msg.warn(
f"Aborting and saving the final best model. "
f"Encountered exception: {repr(e)}"
)
+ "\n"
)
raise e
finally:
finalize_logger()
if output_path is not None:
save_checkpoint(False)
# This will only run if we did't hit an error
if optimizer.averages:
student.use_params(optimizer.averages)
if output_path is not None:
stdout.write(
msg.good("Saved pipeline to output directory", output_path / DIR_MODEL_LAST)
+ "\n"
)
return (student, output_path / DIR_MODEL_LAST)
else:
return (student, None)
def train(
nlp: "Language",
output_path: Optional[Path] = None,
@ -46,11 +173,8 @@ def train(
msg = Printer(no_print=True)
# Create iterator, which yields out info after each optimization step.
config = nlp.config.interpolate()
if config["training"]["seed"] is not None:
fix_random_seed(config["training"]["seed"])
allocator = config["training"]["gpu_allocator"]
if use_gpu >= 0 and allocator:
set_gpu_allocator(allocator)
set_seed_from_config(config)
set_gpu_allocator_from_config(config, use_gpu)
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
dot_names = [T["train_corpus"], T["dev_corpus"]]
train_corpus, dev_corpus = resolve_dot_names(config, dot_names)
@ -139,11 +263,131 @@ def train(
return (nlp, None)
def _distill_loop(
teacher: "Language",
student: "Language",
optimizer: Optimizer,
distill_data: Iterable[List[Example]],
evaluate: Callable[[], Tuple[float, Dict[str, float]]],
*,
dropout: float,
eval_frequency: int,
accumulate_gradient: int,
max_steps: int,
exclude: List[str],
annotating_components: List[str],
before_update: Optional[Callable[["Language", Dict[str, Any]], None]],
student_to_teacher: Dict[str, str],
):
"""Distill until the data is exhausted or the maximum number of steps
has been reached. Works as a generator, with each iteration yielding
a tuple `(batch, info, is_best_checkpoint)`, where info is a dict, and
is_best_checkpoint is in [True, False, None] -- None indicating that
the iteration was not evaluated as a checkpoint. The evaluation is
conducted by calling the evaluate callback.
Positional arguments:
teacher (Language): The teacher pipeline to distill from.
student (Language): The student pipeline to distill into.
optimizer: The optimizer callable.
distill_data (Iterable[List[Example]]): A generator of batches,
with the distillation data. The distillation data iterable
needs to take care of iterating over the epochs and shuffling.
evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
The callback should take no arguments and return a tuple
`(main_score, other_scores)`. The main_score should be a float where
higher is better. other_scores can be any object.
Every iteration, the function yields out a tuple with:
* batch: A list of Example objects.
* info: A dict with various information about the last update (see below).
* is_best_checkpoint: A value in None, False, True, indicating whether this
was the best evaluation so far. You should use this to save the model
checkpoints during training. If None, evaluation was not conducted on
that iteration. False means evaluation was conducted, but a previous
evaluation was better.
The info dict provides the following information:
epoch (int): How many passes over the data have been completed.
step (int): How many steps have been completed.
score (float): The main score from the last evaluation.
other_scores: : The other scores from the last evaluation.
losses: The accumulated losses throughout training.
checkpoints: A list of previous results, where each result is a
(score, step, epoch) tuple.
"""
if isinstance(dropout, float):
dropouts = constant(dropout)
else:
dropouts = dropout
results = []
losses: Dict[str, float] = {}
words_seen = 0
start_time = timer()
for step, (epoch, batch) in enumerate(distill_data):
if before_update:
before_update_args = {"step": step, "epoch": epoch}
before_update(student, before_update_args)
dropout = dropouts(optimizer.step)
for subbatch in subdivide_batch(batch, accumulate_gradient):
student.distill(
teacher,
subbatch,
drop=dropout,
losses=losses,
sgd=False,
exclude=exclude,
annotates=annotating_components,
student_to_teacher=student_to_teacher,
)
# TODO: refactor this so we don't have to run it separately in here
for student_name, student_proc in student.pipeline:
if (
student_name not in exclude
and isinstance(student_proc, ty.DistillableComponent)
and student_proc.is_distillable
and student_proc.model not in (False, None) # type: ignore[attr-defined]
):
student_proc.finish_update(optimizer) # type: ignore[attr-defined]
optimizer.step_schedules()
if not (step % eval_frequency):
if optimizer.averages:
with student.use_params(optimizer.averages):
score, other_scores = evaluate()
else:
score, other_scores = evaluate()
optimizer.last_score = score # type: ignore[assignment]
results.append((score, step))
is_best_checkpoint = score == max(results)[0]
else:
score, other_scores = (None, None)
is_best_checkpoint = None
words_seen += sum(len(eg) for eg in batch)
info = {
"epoch": epoch,
"step": step,
"score": score,
"other_scores": other_scores,
"losses": losses,
"checkpoints": results,
"seconds": int(timer() - start_time),
"words": words_seen,
}
yield batch, info, is_best_checkpoint
if is_best_checkpoint is not None:
losses = {}
# Stop if we've exhausted our max steps (if specified)
if max_steps and step >= max_steps:
break
def train_while_improving(
nlp: "Language",
optimizer: Optimizer,
train_data,
evaluate,
train_data: Iterable[List[Example]],
evaluate: Callable[[], Tuple[float, Dict[str, float]]],
*,
dropout: float,
eval_frequency: int,
@ -163,10 +407,9 @@ def train_while_improving(
Positional arguments:
nlp: The spaCy pipeline to evaluate.
optimizer: The optimizer callable.
train_data (Iterable[Batch]): A generator of batches, with the training
data. Each batch should be a Sized[Tuple[Input, Annot]]. The training
data iterable needs to take care of iterating over the epochs and
shuffling.
train_data (Iterable[List[Example]]): A generator of batches, with the
training data. The training data iterable needs to take care of
iterating over the epochs and shuffling.
evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
The callback should take no arguments and return a tuple
`(main_score, other_scores)`. The main_score should be a float where
@ -230,7 +473,7 @@ def train_while_improving(
score, other_scores = evaluate()
else:
score, other_scores = evaluate()
optimizer.last_score = score
optimizer.last_score = score # type: ignore[assignment]
results.append((score, step))
is_best_checkpoint = score == max(results)[0]
else:
@ -262,9 +505,15 @@ def train_while_improving(
break
def subdivide_batch(batch, accumulate_gradient):
def subdivide_batch(
batch: Union[Iterable[Doc], Iterable[Example]], accumulate_gradient: int
):
batch = list(batch)
batch.sort(key=lambda eg: len(eg.predicted))
if len(batch):
if isinstance(batch[0], Example):
batch.sort(key=lambda eg: len(eg.predicted))
else:
batch.sort(key=lambda doc: len(doc))
sub_len = len(batch) // accumulate_gradient
start = 0
for i in range(accumulate_gradient):
@ -309,6 +558,22 @@ def create_evaluation_callback(
return evaluate
def create_distill_batches(
nlp: "Language",
corpus: Callable[["Language"], Iterable[Example]],
batcher: Callable[[Iterable[Example]], Iterable[List[Example]]],
max_epochs: int,
):
"""Create distillation batches. In contrast to training, the corpus
is normally too large to load into memory and shuffle."""
epoch = 0
while max_epochs < 1 or epoch != max_epochs:
examples = corpus(nlp)
for batch in batcher(examples):
yield epoch, batch
epoch += 1
def create_train_batches(
nlp: "Language",
corpus: Callable[["Language"], Iterable[Example]],

View File

@ -11,6 +11,7 @@ from pathlib import Path
import thinc
from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
from thinc.api import ConfigValidationError, Model, constant as constant_schedule
from thinc.api import fix_random_seed, set_gpu_allocator
import functools
import itertools
import numpy
@ -1790,3 +1791,22 @@ def find_available_port(start: int, host: str, auto_select: bool = False) -> int
# if we get here, the port changed
warnings.warn(Warnings.W124.format(host=host, port=start, serve_port=port))
return port
def set_gpu_allocator_from_config(config: Config, use_gpu: int):
"""Change the global GPU allocator based to the value in
the configuration."""
if "gpu_allocator" not in config["training"]:
raise ValueError(Errors.E1015.format(value="[training] gpu_allocator"))
allocator = config["training"]["gpu_allocator"]
if use_gpu >= 0 and allocator:
set_gpu_allocator(allocator)
def set_seed_from_config(config: Config):
"""Set the random number generator seed to the value in
the configuration."""
if "seed" not in config["training"]:
raise ValueError(Errors.E1015.format(value="[training] seed"))
if config["training"]["seed"] is not None:
fix_random_seed(config["training"]["seed"])

View File

@ -347,19 +347,19 @@ Distill the models in a student pipeline from a teacher pipeline.
> student.distill(teacher, examples, sgd=optimizer)
> ```
| Name | Description |
| -------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher` | The teacher pipeline to distill from. ~~Language~~ |
| `examples` | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
| _keyword-only_ | |
| `drop` | The dropout rate. ~~float~~ |
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
| `losses` | Dictionary to update with the loss, keyed by pipeline component. ~~Optional[Dict[str, float]]~~ |
| `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ |
| `exclude` | Names of components that shouldn't be updated. Defaults to `[]`. ~~Iterable[str]~~ |
| `annotates` | Names of components that should set annotations on the prediced examples after updating. Defaults to `[]`. ~~Iterable[str]~~ |
| `student_to_teacher` | Map student component names to teacher component names, only necessary when the names differ. Defaults to `None`. ~~Optional[Dict[str, str]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
| Name | Description |
| -------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher` | The teacher pipeline to distill from. ~~Language~~ |
| `examples` | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
| _keyword-only_ | |
| `drop` | The dropout rate. ~~float~~ |
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if `None`. No optimizer will be used when set to `False`. Defaults to `None`. ~~Union[Optimizer, None, Literal[False]]~~ |
| `losses` | Dictionary to update with the loss, keyed by pipeline component. ~~Optional[Dict[str, float]]~~ |
| `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ |
| `exclude` | Names of components that shouldn't be updated. Defaults to `[]`. ~~Iterable[str]~~ |
| `annotates` | Names of components that should set annotations on the prediced examples after updating. Defaults to `[]`. ~~Iterable[str]~~ |
| `student_to_teacher` | Map student component names to teacher component names, only necessary when the names differ. Defaults to `None`. ~~Optional[Dict[str, str]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
## Language.rehearse {id="rehearse",tag="method,experimental",version="3"}