Add distillation loop (#12542)
* Add distillation initialization and loop
* Fix up configuration keys
* Add docstring
* Type annotations
* init_nlp_distill -> init_nlp_student
* Do not resolve dot name distill corpus in initialization (since we don't use it)
* student: do not request use of optimizer in student pipe. We finish up
  the updates once in the training loop instead. Also add the necessary
  logic to `Language.distill` to mirror `Language.update`.
* Correctly determine sort key in subdivide_batch
* Fix _distill_loop docstring wrt. stopping condition
* _distill_loop: fix distill_data docstring. Make similar changes in
  train_while_improving, since it also had incorrect types and missing
  type annotations.
* Move `set_{gpu_allocator,seed}_from_config` to spacy.util
* Update Language.update docs for the sgd argument
* Type annotation

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
Parent: 5d0f48fe69
Commit: 8a5814bf2c
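In short, the commit adds a `distill` entry point that mirrors `train`. For orientation, a minimal sketch of the resulting API (the model paths and the empty-annotation example here are hypothetical, not from the diff):

```python
import spacy
from spacy.training import Example

teacher = spacy.load("teacher_model")  # hypothetical trained pipeline
student = spacy.load("student_model")  # hypothetical student pipeline
optimizer = student.create_optimizer()

examples = [Example.from_dict(student.make_doc("I like green eggs."), {})]
losses = student.distill(teacher, examples, sgd=optimizer)
```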
spacy/language.py:

@@ -1024,7 +1024,7 @@ class Language:
         examples: Iterable[Example],
         *,
         drop: float = 0.0,
-        sgd: Optional[Optimizer] = None,
+        sgd: Union[Optimizer, None, Literal[False]] = None,
         losses: Optional[Dict[str, float]] = None,
         component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
         exclude: Iterable[str] = SimpleFrozenList(),
@@ -1037,7 +1037,9 @@ class Language:
             (teacher) and predicted (student) docs must have the same number of
             tokens and the same orthography.
         drop (float): The dropout rate.
-        sgd (Optional[Optimizer]): An optimizer.
+        sgd (Union[Optimizer, None, Literal[False]]): An optimizer. Will
+            be created via create_optimizer if 'None'. No optimizer will
+            be used when set to 'False'.
         losses (Optional(Dict[str, float])): Dictionary to update with the loss,
             keyed by component.
         component_cfg (Optional[Dict[str, Dict[str, Any]]]): Config parameters
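The three accepted values of `sgd` after this change, as a minimal sketch (assuming `teacher`, `student`, `examples`, and `optimizer` are already set up as above):

```python
# Apply updates with the given optimizer.
student.distill(teacher, examples, sgd=optimizer)
# Create an optimizer via create_optimizer.
student.distill(teacher, examples, sgd=None)
# Only accumulate gradients; the caller applies them later
# via each component's finish_update.
student.distill(teacher, examples, sgd=False)
```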
spacy/language.py (continued):

@@ -1107,11 +1109,23 @@ class Language:
                 student_proc.distill(
                     teacher_pipe,
                     examples,
-                    sgd=sgd,
+                    sgd=None,
                     losses=losses,
                     **component_cfg[student_name],
                 )

+        # Only finish the update after all component updates are done. Some
+        # components may share weights (such as tok2vec) and we only want
+        # to apply weight updates after all gradients are accumulated.
+        for student_name, student_proc in self.pipeline:
+            if (
+                student_name not in exclude
+                and isinstance(student_proc, ty.DistillableComponent)
+                and student_proc.is_distillable
+                and sgd not in (None, False)
+            ):
+                student_proc.finish_update(sgd)
+
         return losses

     def disable_pipes(self, *names) -> "DisabledPipes":
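Why the deferred `finish_update` pass matters: components can share a `tok2vec` layer through listeners, so gradients from all student components should be accumulated before a single weight update is applied per component. A rough sketch of the equivalent caller-side pattern (hypothetical loop; `Language.distill` now does this internally when `sgd` is an optimizer):

```python
# Gradients accumulate across all components first ...
losses = student.distill(teacher, examples, sgd=False)
# ... then one combined update is applied per component, so shared
# tok2vec weights are updated exactly once.
for name, proc in student.pipeline:
    if hasattr(proc, "finish_update"):
        proc.finish_update(optimizer)
```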
spacy/language.py (continued):

@@ -1837,7 +1851,7 @@ class Language:
         # using the nlp.config with all defaults.
         config = util.copy_config(config)
         orig_pipeline = config.pop("components", {})
-        orig_distill = config.pop("distill", None)
+        orig_distill = config.pop("distillation", None)
         orig_pretraining = config.pop("pretraining", None)
         config["components"] = {}
         if auto_fill:
@@ -1847,8 +1861,8 @@ class Language:
             filled["components"] = orig_pipeline
             config["components"] = orig_pipeline
         if orig_distill is not None:
-            filled["distill"] = orig_distill
-            config["distill"] = orig_distill
+            filled["distillation"] = orig_distill
+            config["distillation"] = orig_distill
         if orig_pretraining is not None:
             filled["pretraining"] = orig_pretraining
             config["pretraining"] = orig_pretraining
spacy/schemas.py:

@@ -462,7 +462,7 @@ CONFIG_SCHEMAS = {
    "training": ConfigSchemaTraining,
    "pretraining": ConfigSchemaPretrain,
    "initialize": ConfigSchemaInit,
-   "distill": ConfigSchemaDistill,
+   "distillation": ConfigSchemaDistill,
 }
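With the schema registered under the new key, a user config's `[distillation]` block resolves like the other top-level sections. A minimal sketch (assuming `config` is an interpolated `Config`):

```python
from spacy.schemas import ConfigSchemaDistill
from spacy.util import registry

# Resolve the renamed [distillation] section against its schema.
D = registry.resolve(config["distillation"], schema=ConfigSchemaDistill)
```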
spacy/tests/training/test_loop.py (new file, 111 lines):

@@ -0,0 +1,111 @@
+from typing import Callable, Iterable, Iterator
+
+import pytest
+from spacy import Language
+from spacy.training import Example
+from spacy.training.initialize import init_nlp_student
+from spacy.training.loop import distill, train
+from spacy.util import load_model_from_config, registry
+from thinc.api import Config
+
+
+@pytest.fixture
+def config_str():
+    return """
+[nlp]
+lang = "en"
+pipeline = ["senter"]
+disabled = []
+before_creation = null
+after_creation = null
+after_pipeline_creation = null
+batch_size = 1000
+tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"}
+
+[components]
+
+[components.senter]
+factory = "senter"
+
+[training]
+dev_corpus = "corpora.dev"
+train_corpus = "corpora.train"
+max_steps = 50
+seed = 1
+gpu_allocator = null
+
+[distillation]
+corpus = "corpora.train"
+dropout = 0.1
+max_epochs = 0
+max_steps = 50
+student_to_teacher = {}
+
+[distillation.batcher]
+@batchers = "spacy.batch_by_words.v1"
+size = 3000
+discard_oversize = false
+tolerance = 0.2
+
+[distillation.optimizer]
+@optimizers = "Adam.v1"
+beta1 = 0.9
+beta2 = 0.999
+L2_is_weight_decay = true
+L2 = 0.01
+grad_clip = 1.0
+use_averages = true
+eps = 1e-8
+learn_rate = 1e-4
+
+[corpora]
+
+[corpora.dev]
+@readers = "sentence_corpus"
+
+[corpora.train]
+@readers = "sentence_corpus"
+"""
+
+
+SENT_STARTS = [0] * 14
+SENT_STARTS[0] = 1
+SENT_STARTS[5] = 1
+SENT_STARTS[9] = 1
+
+TRAIN_DATA = [
+    (
+        "I like green eggs. Eat blue ham. I like purple eggs.",
+        {"sent_starts": SENT_STARTS},
+    ),
+    (
+        "She likes purple eggs. They hate ham. You like yellow eggs.",
+        {"sent_starts": SENT_STARTS},
+    ),
+]
+
+
+@pytest.mark.slow
+def test_distill_loop(config_str):
+    @registry.readers("sentence_corpus")
+    def create_sentence_corpus() -> Callable[[Language], Iterable[Example]]:
+        return SentenceCorpus()
+
+    class SentenceCorpus:
+        def __call__(self, nlp: Language) -> Iterator[Example]:
+            for t in TRAIN_DATA:
+                yield Example.from_dict(nlp.make_doc(t[0]), t[1])
+
+    orig_config = Config().from_str(config_str)
+    teacher = load_model_from_config(orig_config, auto_fill=True, validate=True)
+    teacher.initialize()
+    train(teacher)
+
+    orig_config = Config().from_str(config_str)
+    student = init_nlp_student(orig_config, teacher)
+    student.initialize()
+    distill(teacher, student)
+
+    doc = student(TRAIN_DATA[0][0])
+    assert doc.sents[0].text == "I like green eggs."
+    assert doc.sents[1].text == "Eat blue ham."
+    assert doc.sents[2].text == "I like purple eggs."
spacy/training/initialize.py:

@@ -1,6 +1,5 @@
 from typing import Union, Dict, Optional, Any, IO, TYPE_CHECKING
-from thinc.api import Config, fix_random_seed, set_gpu_allocator
-from thinc.api import ConfigValidationError
+from thinc.api import Config, ConfigValidationError
 from pathlib import Path
 import srsly
 import numpy
@@ -15,10 +14,11 @@ from .pretrain import get_tok2vec_ref
 from ..lookups import Lookups
 from ..vectors import Vectors, Mode as VectorsMode
 from ..errors import Errors, Warnings
-from ..schemas import ConfigSchemaTraining
+from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
 from ..util import registry, load_model_from_config, resolve_dot_names, logger
 from ..util import load_model, ensure_path, get_sourced_components
 from ..util import OOV_RANK, DEFAULT_OOV_PROB
+from ..util import set_gpu_allocator_from_config, set_seed_from_config

 if TYPE_CHECKING:
     from ..language import Language  # noqa: F401
@@ -27,15 +27,8 @@ if TYPE_CHECKING:
 def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
     raw_config = config
     config = raw_config.interpolate()
-    if "seed" not in config["training"]:
-        raise ValueError(Errors.E1015.format(value="[training] seed"))
-    if "gpu_allocator" not in config["training"]:
-        raise ValueError(Errors.E1015.format(value="[training] gpu_allocator"))
-    if config["training"]["seed"] is not None:
-        fix_random_seed(config["training"]["seed"])
-    allocator = config["training"]["gpu_allocator"]
-    if use_gpu >= 0 and allocator:
-        set_gpu_allocator(allocator)
+    set_seed_from_config(config)
+    set_gpu_allocator_from_config(config, use_gpu)
     # Use original config here before it's resolved to functions
     sourced = get_sourced_components(config)
     nlp = load_model_from_config(raw_config, auto_fill=True)
@@ -101,6 +94,102 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
     return nlp


+def init_nlp_student(
+    config: Config, teacher: "Language", *, use_gpu: int = -1
+) -> "Language":
+    """Initialize student pipeline for distillation.
+
+    config (Config): Student model configuration.
+    teacher (Language): The teacher pipeline to distill from.
+    use_gpu (int): Whether to train on GPU. Make sure to call require_gpu
+        before calling this function.
+    """
+    raw_config = config
+    config = raw_config.interpolate()
+    set_seed_from_config(config)
+    set_gpu_allocator_from_config(config, use_gpu)
+
+    # Use original config here before it's resolved to functions
+    sourced = get_sourced_components(config)
+    nlp = load_model_from_config(raw_config, auto_fill=True)
+    logger.info("Set up nlp object from config")
+    config = nlp.config.interpolate()
+    # Resolve all training-relevant sections using the filled nlp config
+    T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
+    D = registry.resolve(config["distillation"], schema=ConfigSchemaDistill)
+    dot_names = [T["dev_corpus"]]
+    if not isinstance(D["corpus"], str):
+        raise ConfigValidationError(
+            desc=Errors.E897.format(field="distillation.corpus", type=type(D["corpus"]))
+        )
+    if not isinstance(T["dev_corpus"], str):
+        raise ConfigValidationError(
+            desc=Errors.E897.format(
+                field="training.dev_corpus", type=type(T["dev_corpus"])
+            )
+        )
+    (dev_corpus,) = resolve_dot_names(config, dot_names)
+    optimizer = T["optimizer"]
+    # Components that shouldn't be updated during training
+    frozen_components = T["frozen_components"]
+    # Sourced components that require resume_training
+    resume_components = [p for p in sourced if p not in frozen_components]
+    logger.info(f"Pipeline: {nlp.pipe_names}")
+    if resume_components:
+        with nlp.select_pipes(enable=resume_components):
+            logger.info(f"Resuming training for: {resume_components}")
+            nlp.resume_training(sgd=optimizer)
+    # Make sure that listeners are defined before initializing further
+    nlp._link_components()
+
+    # Get teacher labels to initialize student with.
+    student_to_teacher = D["student_to_teacher"]
+    teacher_pipes = dict(teacher.pipeline)
+    labels = {}
+    for name, pipe in nlp.pipeline:
+        # Copy teacher labels.
+        teacher_pipe_name = (
+            student_to_teacher[name] if name in student_to_teacher else name
+        )
+        teacher_pipe = teacher_pipes.get(teacher_pipe_name, None)
+        if (
+            teacher_pipe is not None
+            and getattr(teacher_pipe, "label_data", None) is not None
+        ):
+            labels[name] = teacher_pipe.label_data  # type: ignore[attr-defined]
+
+    with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
+        # Initialize on the dev corpus, since the distillation corpus does
+        # usually not have labels. Since we copy the labels from the teacher
+        # pipe, the dev data does not have to be exhaustive.
+        if T["max_epochs"] == -1:
+            sample_size = 100
+            logger.debug(
+                f"Due to streamed train corpus, using only first {sample_size} "
+                f"examples for initialization. If necessary, provide all labels "
+                f"in [initialize]. More info: https://spacy.io/api/cli#init_labels"
+            )
+            nlp.initialize(lambda: islice(dev_corpus(nlp), sample_size), sgd=optimizer)
+        else:
+            nlp.initialize(lambda: dev_corpus(nlp), sgd=optimizer, labels=labels)
+        logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
+    # Detect components with listeners that are not frozen consistently
+    for name, proc in nlp.pipeline:
+        for listener in getattr(
+            proc, "listening_components", []
+        ):  # e.g. tok2vec/transformer
+            # Don't warn about components not in the pipeline
+            if listener not in nlp.pipe_names:
+                continue
+            if listener in frozen_components and name not in frozen_components:
+                logger.warning(Warnings.W087.format(name=name, listener=listener))
+            # We always check this regardless, in case user freezes tok2vec
+            if listener not in frozen_components and name in frozen_components:
+                if name not in T["annotating_components"]:
+                    logger.warning(Warnings.W086.format(name=name, listener=listener))
+    return nlp
+
+
 def init_vocab(
     nlp: "Language",
     *,
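A sketch of how `init_nlp_student` slots into a distillation workflow (the config path is hypothetical; the config needs `[training]` and `[distillation]` sections as in the test file above, and `teacher` is an already-trained pipeline):

```python
from thinc.api import Config
from spacy.training.initialize import init_nlp_student

student_config = Config().from_disk("student_config.cfg")  # hypothetical path
student = init_nlp_student(student_config, teacher)  # copies teacher labels
student.initialize()
```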
spacy/training/loop.py:

@@ -2,16 +2,20 @@ from typing import List, Callable, Tuple, Dict, Iterable, Union, Any, IO
 from typing import Optional, TYPE_CHECKING
 from pathlib import Path
 from timeit import default_timer as timer
-from thinc.api import Optimizer, Config, constant, fix_random_seed, set_gpu_allocator
+from thinc.api import Optimizer, Config, constant
 from wasabi import Printer
 import random
 import sys
 import shutil
+

 from .example import Example
-from ..schemas import ConfigSchemaTraining
+from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
 from ..errors import Errors
+from ..tokens.doc import Doc
+from .. import ty
 from ..util import resolve_dot_names, registry, logger
+from ..util import set_gpu_allocator_from_config, set_seed_from_config

 if TYPE_CHECKING:
     from ..language import Language  # noqa: F401
@@ -21,6 +25,129 @@ DIR_MODEL_BEST = "model-best"
 DIR_MODEL_LAST = "model-last"


+def distill(
+    teacher: "Language",
+    student: "Language",
+    output_path: Optional[Path] = None,
+    *,
+    use_gpu: int = -1,
+    stdout: IO = sys.stdout,
+    stderr: IO = sys.stderr,
+) -> Tuple["Language", Optional[Path]]:
+    """Distill a student pipeline from a teacher pipeline.
+
+    teacher (Language): The teacher pipeline to distill from.
+    student (Language): The student pipeline to distill into.
+    output_path (Optional[Path]): Optional output path to save the student
+        model to.
+    use_gpu (int): Whether to train on GPU. Make sure to call require_gpu
+        before calling this function.
+    stdout (file): A file-like object to write output messages. To disable
+        printing, set to io.StringIO.
+    stderr (file): A second file-like object to write output messages. To disable
+        printing, set to io.StringIO.
+
+    RETURNS (tuple): The final student nlp object and the path to the exported
+        student model.
+    """
+    # We use no_print here so we can respect the stdout/stderr options.
+    msg = Printer(no_print=True)
+    # Create iterator, which yields out info after each optimization step.
+    config = student.config.interpolate()
+    set_seed_from_config(config)
+    set_gpu_allocator_from_config(config, use_gpu)
+    T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
+    D = registry.resolve(config["distillation"], schema=ConfigSchemaDistill)
+    dot_names = [D["corpus"], T["dev_corpus"]]
+    distill_corpus, dev_corpus = resolve_dot_names(config, dot_names)
+    optimizer = D["optimizer"]
+    score_weights = T["score_weights"]
+    batcher = D["batcher"]
+    train_logger = T["logger"]
+    before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
+    before_update = T["before_update"]
+    student_to_teacher = D["student_to_teacher"]
+
+    # Helper function to save checkpoints. This is a closure for convenience,
+    # to avoid passing in all the args all the time.
+    def save_checkpoint(is_best):
+        with student.use_params(optimizer.averages):
+            before_to_disk(student).to_disk(output_path / DIR_MODEL_LAST)
+        if is_best:
+            # Avoid saving twice (saving will be more expensive than
+            # the dir copy)
+            if (output_path / DIR_MODEL_BEST).exists():
+                shutil.rmtree(output_path / DIR_MODEL_BEST)
+            shutil.copytree(output_path / DIR_MODEL_LAST, output_path / DIR_MODEL_BEST)
+
+    # Components that shouldn't be updated during training
+    frozen_components = T["frozen_components"]
+    # Components that should set annotations on update
+    annotating_components = T["annotating_components"]
+    # Create iterator, which yields out info after each optimization step.
+    training_step_iterator = _distill_loop(
+        teacher,
+        student,
+        optimizer,
+        create_distill_batches(student, distill_corpus, batcher, D["max_epochs"]),
+        create_evaluation_callback(student, dev_corpus, score_weights),
+        dropout=D["dropout"],
+        accumulate_gradient=T["accumulate_gradient"],
+        max_steps=D["max_steps"],
+        eval_frequency=T["eval_frequency"],
+        exclude=frozen_components,
+        annotating_components=annotating_components,
+        before_update=before_update,
+        student_to_teacher=student_to_teacher,
+    )
+    clean_output_dir(output_path)
+    stdout.write(msg.info(f"Teacher pipeline: {teacher.pipe_names}") + "\n")
+    stdout.write(msg.info(f"Student pipeline: {student.pipe_names}") + "\n")
+    if frozen_components:
+        stdout.write(msg.info(f"Frozen components: {frozen_components}") + "\n")
+    if annotating_components:
+        stdout.write(
+            msg.info(f"Set annotations on update for: {annotating_components}") + "\n"
+        )
+    stdout.write(msg.info(f"Initial learn rate: {optimizer.learn_rate(step=0)}") + "\n")
+    with student.select_pipes(disable=frozen_components):
+        log_step, finalize_logger = train_logger(student, stdout, stderr)
+    try:
+        for batch, info, is_best_checkpoint in training_step_iterator:
+            if is_best_checkpoint is not None:
+                with student.select_pipes(disable=frozen_components):
+                    update_meta(T, student, info)
+                if output_path is not None:
+                    save_checkpoint(is_best_checkpoint)
+                    info["output_path"] = str(output_path / DIR_MODEL_LAST)
+            log_step(info if is_best_checkpoint is not None else None)
+    except Exception as e:
+        if output_path is not None:
+            stdout.write(
+                msg.warn(
+                    f"Aborting and saving the final best model. "
+                    f"Encountered exception: {repr(e)}"
+                )
+                + "\n"
+            )
+        raise e
+    finally:
+        finalize_logger()
+        if output_path is not None:
+            save_checkpoint(False)
+    # This will only run if we did't hit an error
+    if optimizer.averages:
+        student.use_params(optimizer.averages)
+    if output_path is not None:
+        stdout.write(
+            msg.good("Saved pipeline to output directory", output_path / DIR_MODEL_LAST)
+            + "\n"
+        )
+        return (student, output_path / DIR_MODEL_LAST)
+    else:
+        return (student, None)
+
+
 def train(
     nlp: "Language",
     output_path: Optional[Path] = None,
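A sketch of driving the new top-level loop (with `teacher` and `student` set up as above; the output path is hypothetical):

```python
from pathlib import Path
from spacy.training.loop import distill

final_student, model_path = distill(
    teacher, student, output_path=Path("output"), use_gpu=-1
)
```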
spacy/training/loop.py (continued):

@@ -46,11 +173,8 @@ def train(
     msg = Printer(no_print=True)
     # Create iterator, which yields out info after each optimization step.
     config = nlp.config.interpolate()
-    if config["training"]["seed"] is not None:
-        fix_random_seed(config["training"]["seed"])
-    allocator = config["training"]["gpu_allocator"]
-    if use_gpu >= 0 and allocator:
-        set_gpu_allocator(allocator)
+    set_seed_from_config(config)
+    set_gpu_allocator_from_config(config, use_gpu)
     T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
     dot_names = [T["train_corpus"], T["dev_corpus"]]
     train_corpus, dev_corpus = resolve_dot_names(config, dot_names)
spacy/training/loop.py (continued):

@@ -139,11 +263,131 @@ def train(
     return (nlp, None)


+def _distill_loop(
+    teacher: "Language",
+    student: "Language",
+    optimizer: Optimizer,
+    distill_data: Iterable[List[Example]],
+    evaluate: Callable[[], Tuple[float, Dict[str, float]]],
+    *,
+    dropout: float,
+    eval_frequency: int,
+    accumulate_gradient: int,
+    max_steps: int,
+    exclude: List[str],
+    annotating_components: List[str],
+    before_update: Optional[Callable[["Language", Dict[str, Any]], None]],
+    student_to_teacher: Dict[str, str],
+):
+    """Distill until the data is exhausted or the maximum number of steps
+    has been reached. Works as a generator, with each iteration yielding
+    a tuple `(batch, info, is_best_checkpoint)`, where info is a dict, and
+    is_best_checkpoint is in [True, False, None] -- None indicating that
+    the iteration was not evaluated as a checkpoint. The evaluation is
+    conducted by calling the evaluate callback.
+
+    Positional arguments:
+    teacher (Language): The teacher pipeline to distill from.
+    student (Language): The student pipeline to distill into.
+    optimizer: The optimizer callable.
+    distill_data (Iterable[List[Example]]): A generator of batches,
+        with the distillation data. The distillation data iterable
+        needs to take care of iterating over the epochs and shuffling.
+    evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
+        The callback should take no arguments and return a tuple
+        `(main_score, other_scores)`. The main_score should be a float where
+        higher is better. other_scores can be any object.
+
+    Every iteration, the function yields out a tuple with:
+
+    * batch: A list of Example objects.
+    * info: A dict with various information about the last update (see below).
+    * is_best_checkpoint: A value in None, False, True, indicating whether this
+        was the best evaluation so far. You should use this to save the model
+        checkpoints during training. If None, evaluation was not conducted on
+        that iteration. False means evaluation was conducted, but a previous
+        evaluation was better.
+
+    The info dict provides the following information:
+
+        epoch (int): How many passes over the data have been completed.
+        step (int): How many steps have been completed.
+        score (float): The main score from the last evaluation.
+        other_scores: : The other scores from the last evaluation.
+        losses: The accumulated losses throughout training.
+        checkpoints: A list of previous results, where each result is a
+            (score, step, epoch) tuple.
+    """
+    if isinstance(dropout, float):
+        dropouts = constant(dropout)
+    else:
+        dropouts = dropout
+    results = []
+    losses: Dict[str, float] = {}
+    words_seen = 0
+    start_time = timer()
+    for step, (epoch, batch) in enumerate(distill_data):
+        if before_update:
+            before_update_args = {"step": step, "epoch": epoch}
+            before_update(student, before_update_args)
+        dropout = dropouts(optimizer.step)
+        for subbatch in subdivide_batch(batch, accumulate_gradient):
+            student.distill(
+                teacher,
+                subbatch,
+                drop=dropout,
+                losses=losses,
+                sgd=False,
+                exclude=exclude,
+                annotates=annotating_components,
+                student_to_teacher=student_to_teacher,
+            )
+        # TODO: refactor this so we don't have to run it separately in here
+        for student_name, student_proc in student.pipeline:
+            if (
+                student_name not in exclude
+                and isinstance(student_proc, ty.DistillableComponent)
+                and student_proc.is_distillable
+                and student_proc.model not in (False, None)  # type: ignore[attr-defined]
+            ):
+                student_proc.finish_update(optimizer)  # type: ignore[attr-defined]
+        optimizer.step_schedules()
+        if not (step % eval_frequency):
+            if optimizer.averages:
+                with student.use_params(optimizer.averages):
+                    score, other_scores = evaluate()
+            else:
+                score, other_scores = evaluate()
+            optimizer.last_score = score  # type: ignore[assignment]
+            results.append((score, step))
+            is_best_checkpoint = score == max(results)[0]
+        else:
+            score, other_scores = (None, None)
+            is_best_checkpoint = None
+        words_seen += sum(len(eg) for eg in batch)
+        info = {
+            "epoch": epoch,
+            "step": step,
+            "score": score,
+            "other_scores": other_scores,
+            "losses": losses,
+            "checkpoints": results,
+            "seconds": int(timer() - start_time),
+            "words": words_seen,
+        }
+        yield batch, info, is_best_checkpoint
+        if is_best_checkpoint is not None:
+            losses = {}
+        # Stop if we've exhausted our max steps (if specified)
+        if max_steps and step >= max_steps:
+            break
+
+
 def train_while_improving(
     nlp: "Language",
     optimizer: Optimizer,
-    train_data,
-    evaluate,
+    train_data: Iterable[List[Example]],
+    evaluate: Callable[[], Tuple[float, Dict[str, float]]],
     *,
     dropout: float,
     eval_frequency: int,
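Each step of `_distill_loop` yields the same `(batch, info, is_best_checkpoint)` triples as the training loop. A sketch of what a consumer sees (assuming `training_step_iterator` is the generator created in `distill()` above; normally `distill()` is the only consumer):

```python
for batch, info, is_best_checkpoint in training_step_iterator:
    # info carries step/epoch counters, losses, and evaluation scores.
    print(info["epoch"], info["step"], info["losses"], is_best_checkpoint)
```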
spacy/training/loop.py (continued):

@@ -163,10 +407,9 @@ def train_while_improving(
     Positional arguments:
     nlp: The spaCy pipeline to evaluate.
     optimizer: The optimizer callable.
-    train_data (Iterable[Batch]): A generator of batches, with the training
-        data. Each batch should be a Sized[Tuple[Input, Annot]]. The training
-        data iterable needs to take care of iterating over the epochs and
-        shuffling.
+    train_data (Iterable[List[Example]]): A generator of batches, with the
+        training data. The training data iterable needs to take care of
+        iterating over the epochs and shuffling.
     evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
         The callback should take no arguments and return a tuple
         `(main_score, other_scores)`. The main_score should be a float where
spacy/training/loop.py (continued):

@@ -230,7 +473,7 @@ def train_while_improving(
                     score, other_scores = evaluate()
             else:
                 score, other_scores = evaluate()
-            optimizer.last_score = score
+            optimizer.last_score = score  # type: ignore[assignment]
             results.append((score, step))
             is_best_checkpoint = score == max(results)[0]
         else:
spacy/training/loop.py (continued):

@@ -262,9 +505,15 @@ def train_while_improving(
             break


-def subdivide_batch(batch, accumulate_gradient):
+def subdivide_batch(
+    batch: Union[Iterable[Doc], Iterable[Example]], accumulate_gradient: int
+):
     batch = list(batch)
-    batch.sort(key=lambda eg: len(eg.predicted))
+    if len(batch):
+        if isinstance(batch[0], Example):
+            batch.sort(key=lambda eg: len(eg.predicted))
+        else:
+            batch.sort(key=lambda doc: len(doc))
     sub_len = len(batch) // accumulate_gradient
     start = 0
     for i in range(accumulate_gradient):
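This is the commit message's "correctly determine sort key" fix: batches may now contain bare `Doc`s (during distillation) or `Example`s (during training), so the sort key is chosen per type. A quick sketch of use (assuming `examples` is a list of `Example`s):

```python
from spacy.training.loop import subdivide_batch

# Examples are sorted by the length of their predicted doc, bare Docs
# by their own length, before splitting into subbatches.
subbatches = list(subdivide_batch(examples, accumulate_gradient=2))
```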
spacy/training/loop.py (continued):

@@ -309,6 +558,22 @@ def create_evaluation_callback(
     return evaluate


+def create_distill_batches(
+    nlp: "Language",
+    corpus: Callable[["Language"], Iterable[Example]],
+    batcher: Callable[[Iterable[Example]], Iterable[List[Example]]],
+    max_epochs: int,
+):
+    """Create distillation batches. In contrast to training, the corpus
+    is normally too large to load into memory and shuffle."""
+    epoch = 0
+    while max_epochs < 1 or epoch != max_epochs:
+        examples = corpus(nlp)
+        for batch in batcher(examples):
+            yield epoch, batch
+        epoch += 1
+
+
 def create_train_batches(
     nlp: "Language",
     corpus: Callable[["Language"], Iterable[Example]],
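Unlike `create_train_batches`, this streams the corpus fresh each epoch instead of materializing and shuffling it. A sketch of consuming it (assuming `corpus` and `batcher` have been resolved from the config):

```python
for epoch, batch in create_distill_batches(student, corpus, batcher, max_epochs=2):
    # Each item is an (epoch, List[Example]) pair; max_epochs < 1 streams forever.
    ...
```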
spacy/util.py:

@@ -11,6 +11,7 @@ from pathlib import Path
 import thinc
 from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
 from thinc.api import ConfigValidationError, Model, constant as constant_schedule
+from thinc.api import fix_random_seed, set_gpu_allocator
 import functools
 import itertools
 import numpy
@@ -1790,3 +1791,22 @@ def find_available_port(start: int, host: str, auto_select: bool = False) -> int
     # if we get here, the port changed
     warnings.warn(Warnings.W124.format(host=host, port=start, serve_port=port))
     return port
+
+
+def set_gpu_allocator_from_config(config: Config, use_gpu: int):
+    """Change the global GPU allocator based to the value in
+    the configuration."""
+    if "gpu_allocator" not in config["training"]:
+        raise ValueError(Errors.E1015.format(value="[training] gpu_allocator"))
+    allocator = config["training"]["gpu_allocator"]
+    if use_gpu >= 0 and allocator:
+        set_gpu_allocator(allocator)
+
+
+def set_seed_from_config(config: Config):
+    """Set the random number generator seed to the value in
+    the configuration."""
+    if "seed" not in config["training"]:
+        raise ValueError(Errors.E1015.format(value="[training] seed"))
+    if config["training"]["seed"] is not None:
+        fix_random_seed(config["training"]["seed"])
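These two helpers centralize what `init_nlp`, `train`, and now `distill` each did inline. A sketch of the pattern they replace (assuming `nlp` is a loaded pipeline):

```python
from spacy.util import set_gpu_allocator_from_config, set_seed_from_config

config = nlp.config.interpolate()
set_seed_from_config(config)               # validates and applies [training] seed
set_gpu_allocator_from_config(config, -1)  # no-op unless use_gpu >= 0
```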
website/docs/api/language.mdx:

@@ -347,19 +347,19 @@ Distill the models in a student pipeline from a teacher pipeline.
 > student.distill(teacher, examples, sgd=optimizer)
 > ```

 | Name | Description |
-| -------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| -------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
 | `teacher` | The teacher pipeline to distill from. ~~Language~~ |
 | `examples` | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
 | _keyword-only_ | |
 | `drop` | The dropout rate. ~~float~~ |
-| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
+| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if `None`. No optimizer will be used when set to `False`. Defaults to `None`. ~~Union[Optimizer, None, Literal[False]]~~ |
 | `losses` | Dictionary to update with the loss, keyed by pipeline component. ~~Optional[Dict[str, float]]~~ |
 | `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ |
 | `exclude` | Names of components that shouldn't be updated. Defaults to `[]`. ~~Iterable[str]~~ |
 | `annotates` | Names of components that should set annotations on the prediced examples after updating. Defaults to `[]`. ~~Iterable[str]~~ |
 | `student_to_teacher` | Map student component names to teacher component names, only necessary when the names differ. Defaults to `None`. ~~Optional[Dict[str, str]]~~ |
 | **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |

 ## Language.rehearse {id="rehearse",tag="method,experimental",version="3"}