Add distillation initialization and loop

Daniël de Kok 2023-04-18 13:55:55 +02:00
parent 5d0f48fe69
commit 9a72ea0b91
3 changed files with 484 additions and 12 deletions

View File

@@ -0,0 +1,111 @@
from typing import Callable, Iterable, Iterator

import pytest
from spacy import Language
from spacy.training import Example
from spacy.training.initialize import init_nlp_distill
from spacy.training.loop import distill, train
from spacy.util import load_model_from_config, registry
from thinc.api import Config


@pytest.fixture
def config_str():
    return """
[nlp]
lang = "en"
pipeline = ["senter"]
disabled = []
before_creation = null
after_creation = null
after_pipeline_creation = null
batch_size = 1000
tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"}

[components]

[components.senter]
factory = "senter"

[training]
dev_corpus = "corpora.dev"
train_corpus = "corpora.train"
max_steps = 50
seed = 1
gpu_allocator = null

[distillation]
corpus = "corpora.train"
dropout = 0.1
max_epochs = 0
max_steps = 50
student_to_teacher = {}

[distillation.batcher]
@batchers = "spacy.batch_by_words.v1"
size = 3000
discard_oversize = false
tolerance = 0.2

[distillation.optimizer]
@optimizers = "Adam.v1"
beta1 = 0.9
beta2 = 0.999
L2_is_weight_decay = true
L2 = 0.01
grad_clip = 1.0
use_averages = true
eps = 1e-8
learn_rate = 1e-4

[corpora]

[corpora.dev]
@readers = "sentence_corpus"

[corpora.train]
@readers = "sentence_corpus"
"""


SENT_STARTS = [0] * 14
SENT_STARTS[0] = 1
SENT_STARTS[5] = 1
SENT_STARTS[9] = 1

TRAIN_DATA = [
    (
        "I like green eggs. Eat blue ham. I like purple eggs.",
        {"sent_starts": SENT_STARTS},
    ),
    (
        "She likes purple eggs. They hate ham. You like yellow eggs.",
        {"sent_starts": SENT_STARTS},
    ),
]


@pytest.mark.slow
def test_distill_loop(config_str):
    @registry.readers("sentence_corpus")
    def create_sentence_corpus() -> Callable[[Language], Iterable[Example]]:
        return SentenceCorpus()

    class SentenceCorpus:
        def __call__(self, nlp: Language) -> Iterator[Example]:
            for t in TRAIN_DATA:
                yield Example.from_dict(nlp.make_doc(t[0]), t[1])

    orig_config = Config().from_str(config_str)
    teacher = load_model_from_config(orig_config, auto_fill=True, validate=True)
    teacher.initialize()
    train(teacher)

    orig_config = Config().from_str(config_str)
    student = init_nlp_distill(orig_config, teacher)
    student.initialize()
    distill(teacher, student)

    doc = student(TRAIN_DATA[0][0])
    # Doc.sents is a generator, so materialize it before indexing.
    sents = list(doc.sents)
    assert sents[0].text == "I like green eggs."
    assert sents[1].text == "Eat blue ham."
    assert sents[2].text == "I like purple eggs."
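
For orientation, the same flow outside the test harness might look as follows. This is a minimal sketch, assuming a trained teacher pipeline on disk and a student config containing the [distillation] block shown above; all paths are hypothetical:

from pathlib import Path

import spacy
from spacy.training.initialize import init_nlp_distill
from spacy.training.loop import distill
from spacy.util import load_config

teacher = spacy.load("training/teacher/model-best")  # hypothetical path
config = load_config("student.cfg")  # student config with a [distillation] block
student = init_nlp_distill(config, teacher)
student, model_path = distill(teacher, student, output_path=Path("training/student"))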

spacy/training/initialize.py

@@ -15,7 +15,7 @@ from .pretrain import get_tok2vec_ref
 from ..lookups import Lookups
 from ..vectors import Vectors, Mode as VectorsMode
 from ..errors import Errors, Warnings
-from ..schemas import ConfigSchemaTraining
+from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
 from ..util import registry, load_model_from_config, resolve_dot_names, logger
 from ..util import load_model, ensure_path, get_sourced_components
 from ..util import OOV_RANK, DEFAULT_OOV_PROB
@@ -27,15 +27,8 @@ if TYPE_CHECKING:
 def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
     raw_config = config
     config = raw_config.interpolate()
-    if "seed" not in config["training"]:
-        raise ValueError(Errors.E1015.format(value="[training] seed"))
-    if "gpu_allocator" not in config["training"]:
-        raise ValueError(Errors.E1015.format(value="[training] gpu_allocator"))
-    if config["training"]["seed"] is not None:
-        fix_random_seed(config["training"]["seed"])
-    allocator = config["training"]["gpu_allocator"]
-    if use_gpu >= 0 and allocator:
-        set_gpu_allocator(allocator)
+    _set_seed_from_config(config)
+    _set_gpu_allocator_from_config(config, use_gpu)
     # Use original config here before it's resolved to functions
     sourced = get_sourced_components(config)
     nlp = load_model_from_config(raw_config, auto_fill=True)
@@ -101,6 +94,110 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
    return nlp


def _set_gpu_allocator_from_config(config, use_gpu):
    if "gpu_allocator" not in config["training"]:
        raise ValueError(Errors.E1015.format(value="[training] gpu_allocator"))
    allocator = config["training"]["gpu_allocator"]
    if use_gpu >= 0 and allocator:
        set_gpu_allocator(allocator)


def _set_seed_from_config(config):
    if "seed" not in config["training"]:
        raise ValueError(Errors.E1015.format(value="[training] seed"))
    if config["training"]["seed"] is not None:
        fix_random_seed(config["training"]["seed"])


def init_nlp_distill(
    config: Config, teacher: "Language", *, use_gpu: int = -1
) -> "Language":
    raw_config = config
    config = raw_config.interpolate()
    _set_seed_from_config(config)
    _set_gpu_allocator_from_config(config, use_gpu)
    # Use original config here before it's resolved to functions
    sourced = get_sourced_components(config)
    nlp = load_model_from_config(raw_config, auto_fill=True)
    logger.info("Set up nlp object from config")
    config = nlp.config.interpolate()
    # Resolve all training-relevant sections using the filled nlp config
    T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
    D = registry.resolve(config["distillation"], schema=ConfigSchemaDistill)
    dot_names = [D["corpus"], T["dev_corpus"]]
    if not isinstance(D["corpus"], str):
        raise ConfigValidationError(
            desc=Errors.E897.format(field="distillation.corpus", type=type(D["corpus"]))
        )
    if not isinstance(T["dev_corpus"], str):
        raise ConfigValidationError(
            desc=Errors.E897.format(field="training.dev_corpus", type=type(T["dev_corpus"]))
        )
    distill_corpus, dev_corpus = resolve_dot_names(config, dot_names)
    optimizer = T["optimizer"]
    # Components that shouldn't be updated during training
    frozen_components = T["frozen_components"]
    # Sourced components that require resume_training
    resume_components = [p for p in sourced if p not in frozen_components]
    logger.info(f"Pipeline: {nlp.pipe_names}")
    if resume_components:
        with nlp.select_pipes(enable=resume_components):
            logger.info(f"Resuming training for: {resume_components}")
            nlp.resume_training(sgd=optimizer)
    # Make sure that listeners are defined before initializing further
    nlp._link_components()
    # Get teacher labels to initialize the student with.
    student_to_teacher = D["student_to_teacher"]
    teacher_pipes = dict(teacher.pipeline)
    labels = {}
    for name, pipe in nlp.pipeline:
        # Copy teacher labels, mapping the student pipe name to the teacher
        # pipe name where the two differ.
        teacher_pipe_name = (
            student_to_teacher[name] if name in student_to_teacher else name
        )
        teacher_pipe = teacher_pipes.get(teacher_pipe_name, None)
        if (
            teacher_pipe is not None
            and getattr(teacher_pipe, "label_data", None) is not None
        ):
            labels[name] = teacher_pipe.label_data  # type: ignore[attr-defined]
    with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
        # Initialize on the dev corpus, since the distillation corpus usually
        # does not have labels. Since we copy the labels from the teacher
        # pipe, the dev data does not have to be exhaustive.
        if T["max_epochs"] == -1:
            sample_size = 100
            logger.debug(
                f"Due to streamed train corpus, using only first {sample_size} "
                f"examples for initialization. If necessary, provide all labels "
                f"in [initialize]. More info: https://spacy.io/api/cli#init_labels"
            )
            nlp.initialize(lambda: islice(dev_corpus(nlp), sample_size), sgd=optimizer)
        else:
            nlp.initialize(lambda: dev_corpus(nlp), sgd=optimizer, labels=labels)
        logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
    # Detect components with listeners that are not frozen consistently
    for name, proc in nlp.pipeline:
        for listener in getattr(
            proc, "listening_components", []
        ):  # e.g. tok2vec/transformer
            # Don't warn about components not in the pipeline
            if listener not in nlp.pipe_names:
                continue
            if listener in frozen_components and name not in frozen_components:
                logger.warning(Warnings.W087.format(name=name, listener=listener))
            # We always check this regardless, in case user freezes tok2vec
            if listener not in frozen_components and name in frozen_components:
                if name not in T["annotating_components"]:
                    logger.warning(Warnings.W086.format(name=name, listener=listener))
    return nlp


def init_vocab(
    nlp: "Language",
    *,
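
The student_to_teacher mapping in [distillation] only needs entries for pipes whose names differ between the two pipelines, so that a student pipe can pull its labels (and later its distillation targets) from a teacher pipe registered under another name. A minimal config sketch, with made-up component names:

[distillation]
student_to_teacher = {"senter": "senter_large"}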

spacy/training/loop.py

@@ -9,8 +9,9 @@ import sys
 import shutil

 from .example import Example
-from ..schemas import ConfigSchemaTraining
+from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
 from ..errors import Errors
+from ..tokens.doc import Doc
 from ..util import resolve_dot_names, registry, logger

 if TYPE_CHECKING:
@@ -21,6 +22,132 @@ DIR_MODEL_BEST = "model-best"
 DIR_MODEL_LAST = "model-last"


def distill(
    teacher: "Language",
    student: "Language",
    output_path: Optional[Path] = None,
    *,
    use_gpu: int = -1,
    stdout: IO = sys.stdout,
    stderr: IO = sys.stderr,
) -> Tuple["Language", Optional[Path]]:
    """Distill a student pipeline from a teacher pipeline.

    teacher (Language): The teacher pipeline to distill from.
    student (Language): The student pipeline to distill into.
    output_path (Optional[Path]): Optional output path to save the student
        model to.
    use_gpu (int): Whether to train on GPU. Make sure to call require_gpu
        before calling this function.
    stdout (file): A file-like object to write output messages. To disable
        printing, set to io.StringIO.
    stderr (file): A second file-like object to write output messages. To disable
        printing, set to io.StringIO.
    RETURNS (tuple): The final student nlp object and the path to the exported
        student model.
    """
    # We use no_print here so we can respect the stdout/stderr options.
    msg = Printer(no_print=True)
    config = student.config.interpolate()
    if config["training"]["seed"] is not None:
        fix_random_seed(config["training"]["seed"])
    allocator = config["training"]["gpu_allocator"]
    if use_gpu >= 0 and allocator:
        set_gpu_allocator(allocator)
    T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
    D = registry.resolve(config["distillation"], schema=ConfigSchemaDistill)
    dot_names = [D["corpus"], T["dev_corpus"]]
    distill_corpus, dev_corpus = resolve_dot_names(config, dot_names)
    optimizer = D["optimizer"]
    score_weights = T["score_weights"]
    batcher = D["batcher"]
    train_logger = T["logger"]
    before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
    before_update = T["before_update"]
    student_to_teacher = D["student_to_teacher"]
    # Helper function to save checkpoints. This is a closure for convenience,
    # to avoid passing in all the args all the time.
    def save_checkpoint(is_best):
        with student.use_params(optimizer.averages):
            before_to_disk(student).to_disk(output_path / DIR_MODEL_LAST)
        if is_best:
            # Avoid saving twice (saving will be more expensive than
            # the dir copy)
            if (output_path / DIR_MODEL_BEST).exists():
                shutil.rmtree(output_path / DIR_MODEL_BEST)
            shutil.copytree(output_path / DIR_MODEL_LAST, output_path / DIR_MODEL_BEST)

    # Components that shouldn't be updated during training
    frozen_components = T["frozen_components"]
    # Components that should set annotations on update
    annotating_components = T["annotating_components"]
    # Create iterator, which yields out info after each optimization step.
    training_step_iterator = _distill_loop(
        teacher,
        student,
        optimizer,
        create_distill_batches(student, distill_corpus, batcher, D["max_epochs"]),
        create_evaluation_callback(student, dev_corpus, score_weights),
        dropout=D["dropout"],
        accumulate_gradient=T["accumulate_gradient"],
        max_steps=D["max_steps"],
        eval_frequency=T["eval_frequency"],
        exclude=frozen_components,
        annotating_components=annotating_components,
        before_update=before_update,
        student_to_teacher=student_to_teacher,
    )
    clean_output_dir(output_path)
    stdout.write(msg.info(f"Teacher pipeline: {teacher.pipe_names}") + "\n")
    stdout.write(msg.info(f"Student pipeline: {student.pipe_names}") + "\n")
    if frozen_components:
        stdout.write(msg.info(f"Frozen components: {frozen_components}") + "\n")
    if annotating_components:
        stdout.write(
            msg.info(f"Set annotations on update for: {annotating_components}") + "\n"
        )
stdout.write(msg.info(f"Initial learn rate: {optimizer.learn_rate(step=0)}") + "\n")
with student.select_pipes(disable=frozen_components):
log_step, finalize_logger = train_logger(student, stdout, stderr)
try:
for batch, info, is_best_checkpoint in training_step_iterator:
if is_best_checkpoint is not None:
with student.select_pipes(disable=frozen_components):
update_meta(T, student, info)
if output_path is not None:
save_checkpoint(is_best_checkpoint)
info["output_path"] = str(output_path / DIR_MODEL_LAST)
log_step(info if is_best_checkpoint is not None else None)
except Exception as e:
if output_path is not None:
stdout.write(
msg.warn(
f"Aborting and saving the final best model. "
f"Encountered exception: {repr(e)}"
)
+ "\n"
)
raise e
finally:
finalize_logger()
if output_path is not None:
save_checkpoint(False)
# This will only run if we did't hit an error
if optimizer.averages:
student.use_params(optimizer.averages)
if output_path is not None:
stdout.write(
msg.good("Saved pipeline to output directory", output_path / DIR_MODEL_LAST)
+ "\n"
)
return (student, output_path / DIR_MODEL_LAST)
else:
return (student, None)
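
As the docstring notes, the progress output can be silenced by passing io.StringIO buffers for stdout and stderr; a small sketch, reusing the teacher and student from the earlier example:

import io

student, model_path = distill(
    teacher, student, Path("training/student"),
    stdout=io.StringIO(), stderr=io.StringIO(),
)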


def train(
    nlp: "Language",
    output_path: Optional[Path] = None,
@@ -139,6 +266,126 @@ def train(
    return (nlp, None)


def _distill_loop(
    teacher: "Language",
    student: "Language",
    optimizer: Optimizer,
    distill_data: Iterable[Tuple[int, List[Example]]],
    evaluate: Callable[[], Tuple[float, Dict[str, float]]],
    *,
    dropout: float,
    eval_frequency: int,
    accumulate_gradient: int,
    max_steps: int,
    exclude: List[str],
    annotating_components: List[str],
    before_update: Optional[Callable[["Language", Dict[str, Any]], None]],
    student_to_teacher: Dict[str, str],
):
"""Train until an evaluation stops improving. Works as a generator,
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
where info is a dict, and is_best_checkpoint is in [True, False, None] --
None indicating that the iteration was not evaluated as a checkpoint.
The evaluation is conducted by calling the evaluate callback.
Positional arguments:
teacher (Language): The teacher pipeline to distill from.
student (Language): The student pipeline to distill into.
optimizer: The optimizer callable.
distill_data (Iterable[Batch]): A generator of batches, with the distillation
data. Each batch should be a Sized[Tuple[Input, Annot]]. The distillation
data iterable needs to take care of iterating over the epochs and
shuffling.
evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
The callback should take no arguments and return a tuple
`(main_score, other_scores)`. The main_score should be a float where
higher is better. other_scores can be any object.
Every iteration, the function yields out a tuple with:
* batch: A list of Example objects.
* info: A dict with various information about the last update (see below).
* is_best_checkpoint: A value in None, False, True, indicating whether this
was the best evaluation so far. You should use this to save the model
checkpoints during training. If None, evaluation was not conducted on
that iteration. False means evaluation was conducted, but a previous
evaluation was better.
The info dict provides the following information:
epoch (int): How many passes over the data have been completed.
step (int): How many steps have been completed.
score (float): The main score from the last evaluation.
other_scores: : The other scores from the last evaluation.
losses: The accumulated losses throughout training.
checkpoints: A list of previous results, where each result is a
(score, step, epoch) tuple.
"""
    if isinstance(dropout, float):
        dropouts = constant(dropout)
    else:
        dropouts = dropout
    results = []
    losses: Dict[str, float] = {}
    words_seen = 0
    start_time = timer()
    for step, (epoch, batch) in enumerate(distill_data):
        if before_update:
            before_update_args = {"step": step, "epoch": epoch}
            before_update(student, before_update_args)
        dropout = dropouts(optimizer.step)
        for subbatch in subdivide_batch(batch, accumulate_gradient):
            student.distill(
                teacher,
                subbatch,
                drop=dropout,
                losses=losses,
                sgd=None,
                exclude=exclude,
                annotates=annotating_components,
                student_to_teacher=student_to_teacher,
            )
        # TODO: refactor this so we don't have to run it separately in here
        for name, proc in student.pipeline:
            if (
                name not in exclude
                and hasattr(proc, "is_trainable")
                and proc.is_trainable
                and proc.model not in (True, False, None)  # type: ignore[attr-defined]
            ):
                proc.finish_update(optimizer)  # type: ignore[attr-defined]
        optimizer.step_schedules()
        if not (step % eval_frequency):
            if optimizer.averages:
                with student.use_params(optimizer.averages):
                    score, other_scores = evaluate()
            else:
                score, other_scores = evaluate()
            optimizer.last_score = score  # type: ignore[assignment]
            results.append((score, step))
            is_best_checkpoint = score == max(results)[0]
        else:
            score, other_scores = (None, None)
            is_best_checkpoint = None
        words_seen += sum(len(eg) for eg in batch)
        info = {
            "epoch": epoch,
            "step": step,
            "score": score,
            "other_scores": other_scores,
            "losses": losses,
            "checkpoints": results,
            "seconds": int(timer() - start_time),
            "words": words_seen,
        }
        yield batch, info, is_best_checkpoint
        if is_best_checkpoint is not None:
            losses = {}
        # Stop if we've exhausted our max steps (if specified)
        if max_steps and step >= max_steps:
            break
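
Because dropout is only wrapped in constant() when a float is passed, a thinc schedule can be supplied instead of a fixed rate. A sketch using thinc's decaying schedule, with illustrative values only:

from thinc.api import decaying

# Start at 0.3 and decay a little on each optimizer step; made-up values.
dropouts = decaying(0.3, 1e-5)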


def train_while_improving(
    nlp: "Language",
    optimizer: Optimizer,
@@ -264,7 +511,10 @@

 def subdivide_batch(batch, accumulate_gradient):
     batch = list(batch)
-    batch.sort(key=lambda eg: len(eg.predicted))
+    if batch and isinstance(batch[0], Example):
+        batch.sort(key=lambda eg: len(eg.predicted))
+    elif batch and isinstance(batch[0], Doc):
+        batch.sort(key=lambda doc: len(doc))
     sub_len = len(batch) // accumulate_gradient
     start = 0
     for i in range(accumulate_gradient):
@@ -309,6 +559,20 @@ def create_evaluation_callback(
     return evaluate


def create_distill_batches(
    nlp: "Language",
    corpus: Callable[["Language"], Iterable[Example]],
    batcher: Callable[[Iterable[Example]], Iterable[List[Example]]],
    max_epochs: int,
):
    epoch = 0
    while max_epochs < 1 or epoch != max_epochs:
        examples = corpus(nlp)
        for batch in batcher(examples):
            yield epoch, batch
        epoch += 1


def create_train_batches(
    nlp: "Language",
    corpus: Callable[["Language"], Iterable[Example]],
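
A quick sketch of create_distill_batches' epoch semantics, assuming toy corpus and batcher callables: with max_epochs = 0, as in the test config, the corpus is replayed indefinitely and _distill_loop relies on max_steps to terminate, while a positive value caps the number of passes.

batches = create_distill_batches(nlp, corpus, batcher, max_epochs=2)
epochs_seen = {epoch for epoch, batch in batches}
assert epochs_seen == {0, 1}  # exactly two passes over the corpus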