mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 16:22:29 +03:00
Add distillation initialization and loop
This commit is contained in:
parent
5d0f48fe69
commit
9a72ea0b91
111
spacy/tests/training/test_loop.py
Normal file
111
spacy/tests/training/test_loop.py
Normal file
|
@ -0,0 +1,111 @@
|
||||||
|
from typing import Callable, Iterable, Iterator
|
||||||
|
import pytest
|
||||||
|
from spacy import Language
|
||||||
|
from spacy.training import Example
|
||||||
|
from spacy.training.initialize import init_nlp_distill
|
||||||
|
from spacy.training.loop import distill, train
|
||||||
|
from spacy.util import load_model_from_config, registry
|
||||||
|
from thinc.api import Config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config_str():
|
||||||
|
return """
|
||||||
|
[nlp]
|
||||||
|
lang = "en"
|
||||||
|
pipeline = ["senter"]
|
||||||
|
disabled = []
|
||||||
|
before_creation = null
|
||||||
|
after_creation = null
|
||||||
|
after_pipeline_creation = null
|
||||||
|
batch_size = 1000
|
||||||
|
tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"}
|
||||||
|
|
||||||
|
[components]
|
||||||
|
|
||||||
|
[components.senter]
|
||||||
|
factory = "senter"
|
||||||
|
|
||||||
|
[training]
|
||||||
|
dev_corpus = "corpora.dev"
|
||||||
|
train_corpus = "corpora.train"
|
||||||
|
max_steps = 50
|
||||||
|
seed = 1
|
||||||
|
gpu_allocator = null
|
||||||
|
|
||||||
|
[distillation]
|
||||||
|
corpus = "corpora.train"
|
||||||
|
dropout = 0.1
|
||||||
|
max_epochs = 0
|
||||||
|
max_steps = 50
|
||||||
|
student_to_teacher = {}
|
||||||
|
|
||||||
|
[distillation.batcher]
|
||||||
|
@batchers = "spacy.batch_by_words.v1"
|
||||||
|
size = 3000
|
||||||
|
discard_oversize = false
|
||||||
|
tolerance = 0.2
|
||||||
|
|
||||||
|
[distillation.optimizer]
|
||||||
|
@optimizers = "Adam.v1"
|
||||||
|
beta1 = 0.9
|
||||||
|
beta2 = 0.999
|
||||||
|
L2_is_weight_decay = true
|
||||||
|
L2 = 0.01
|
||||||
|
grad_clip = 1.0
|
||||||
|
use_averages = true
|
||||||
|
eps = 1e-8
|
||||||
|
learn_rate = 1e-4
|
||||||
|
|
||||||
|
[corpora]
|
||||||
|
|
||||||
|
[corpora.dev]
|
||||||
|
@readers = "sentence_corpus"
|
||||||
|
|
||||||
|
[corpora.train]
|
||||||
|
@readers = "sentence_corpus"
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
SENT_STARTS = [0] * 14
|
||||||
|
SENT_STARTS[0] = 1
|
||||||
|
SENT_STARTS[5] = 1
|
||||||
|
SENT_STARTS[9] = 1
|
||||||
|
|
||||||
|
TRAIN_DATA = [
|
||||||
|
(
|
||||||
|
"I like green eggs. Eat blue ham. I like purple eggs.",
|
||||||
|
{"sent_starts": SENT_STARTS},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"She likes purple eggs. They hate ham. You like yellow eggs.",
|
||||||
|
{"sent_starts": SENT_STARTS},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
def test_distill_loop(config_str):
|
||||||
|
@registry.readers("sentence_corpus")
|
||||||
|
def create_sentence_corpus() -> Callable[[Language], Iterable[Example]]:
|
||||||
|
return SentenceCorpus()
|
||||||
|
|
||||||
|
class SentenceCorpus:
|
||||||
|
def __call__(self, nlp: Language) -> Iterator[Example]:
|
||||||
|
for t in TRAIN_DATA:
|
||||||
|
yield Example.from_dict(nlp.make_doc(t[0]), t[1])
|
||||||
|
|
||||||
|
orig_config = Config().from_str(config_str)
|
||||||
|
teacher = load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||||
|
teacher.initialize()
|
||||||
|
train(teacher)
|
||||||
|
|
||||||
|
orig_config = Config().from_str(config_str)
|
||||||
|
student = init_nlp_distill(orig_config, teacher)
|
||||||
|
student.initialize()
|
||||||
|
distill(teacher, student)
|
||||||
|
|
||||||
|
doc = student(TRAIN_DATA[0][0])
|
||||||
|
assert doc.sents[0].text == "I like green eggs."
|
||||||
|
assert doc.sents[1].text == "Eat blue ham."
|
||||||
|
assert doc.sents[2].text == "I like purple eggs."
|
|
@ -15,7 +15,7 @@ from .pretrain import get_tok2vec_ref
|
||||||
from ..lookups import Lookups
|
from ..lookups import Lookups
|
||||||
from ..vectors import Vectors, Mode as VectorsMode
|
from ..vectors import Vectors, Mode as VectorsMode
|
||||||
from ..errors import Errors, Warnings
|
from ..errors import Errors, Warnings
|
||||||
from ..schemas import ConfigSchemaTraining
|
from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
|
||||||
from ..util import registry, load_model_from_config, resolve_dot_names, logger
|
from ..util import registry, load_model_from_config, resolve_dot_names, logger
|
||||||
from ..util import load_model, ensure_path, get_sourced_components
|
from ..util import load_model, ensure_path, get_sourced_components
|
||||||
from ..util import OOV_RANK, DEFAULT_OOV_PROB
|
from ..util import OOV_RANK, DEFAULT_OOV_PROB
|
||||||
|
@ -27,15 +27,8 @@ if TYPE_CHECKING:
|
||||||
def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
||||||
raw_config = config
|
raw_config = config
|
||||||
config = raw_config.interpolate()
|
config = raw_config.interpolate()
|
||||||
if "seed" not in config["training"]:
|
_set_seed_from_config(config)
|
||||||
raise ValueError(Errors.E1015.format(value="[training] seed"))
|
_set_gpu_allocator_from_config(config, use_gpu)
|
||||||
if "gpu_allocator" not in config["training"]:
|
|
||||||
raise ValueError(Errors.E1015.format(value="[training] gpu_allocator"))
|
|
||||||
if config["training"]["seed"] is not None:
|
|
||||||
fix_random_seed(config["training"]["seed"])
|
|
||||||
allocator = config["training"]["gpu_allocator"]
|
|
||||||
if use_gpu >= 0 and allocator:
|
|
||||||
set_gpu_allocator(allocator)
|
|
||||||
# Use original config here before it's resolved to functions
|
# Use original config here before it's resolved to functions
|
||||||
sourced = get_sourced_components(config)
|
sourced = get_sourced_components(config)
|
||||||
nlp = load_model_from_config(raw_config, auto_fill=True)
|
nlp = load_model_from_config(raw_config, auto_fill=True)
|
||||||
|
@ -101,6 +94,110 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
||||||
return nlp
|
return nlp
|
||||||
|
|
||||||
|
|
||||||
|
def _set_gpu_allocator_from_config(config, use_gpu):
|
||||||
|
if "gpu_allocator" not in config["training"]:
|
||||||
|
raise ValueError(Errors.E1015.format(value="[training] gpu_allocator"))
|
||||||
|
allocator = config["training"]["gpu_allocator"]
|
||||||
|
if use_gpu >= 0 and allocator:
|
||||||
|
set_gpu_allocator(allocator)
|
||||||
|
|
||||||
|
|
||||||
|
def _set_seed_from_config(config):
|
||||||
|
if "seed" not in config["training"]:
|
||||||
|
raise ValueError(Errors.E1015.format(value="[training] seed"))
|
||||||
|
if config["training"]["seed"] is not None:
|
||||||
|
fix_random_seed(config["training"]["seed"])
|
||||||
|
|
||||||
|
|
||||||
|
def init_nlp_distill(
|
||||||
|
config: Config, teacher: "Language", *, use_gpu: int = -1
|
||||||
|
) -> "Language":
|
||||||
|
raw_config = config
|
||||||
|
config = raw_config.interpolate()
|
||||||
|
_set_seed_from_config(config)
|
||||||
|
_set_gpu_allocator_from_config(config, use_gpu)
|
||||||
|
|
||||||
|
# Use original config here before it's resolved to functions
|
||||||
|
sourced = get_sourced_components(config)
|
||||||
|
nlp = load_model_from_config(raw_config, auto_fill=True)
|
||||||
|
logger.info("Set up nlp object from config")
|
||||||
|
config = nlp.config.interpolate()
|
||||||
|
# Resolve all training-relevant sections using the filled nlp config
|
||||||
|
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
||||||
|
D = registry.resolve(config["distill"], schema=ConfigSchemaDistill)
|
||||||
|
dot_names = [D["corpus"], T["dev_corpus"]]
|
||||||
|
if not isinstance(D["corpus"], str):
|
||||||
|
raise ConfigValidationError(
|
||||||
|
desc=Errors.E897.format(
|
||||||
|
field="distill.corpus", type=type(D["corpus"])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not isinstance(T["dev_corpus"], str):
|
||||||
|
raise ConfigValidationError(
|
||||||
|
desc=Errors.E897.format(
|
||||||
|
field="training.dev_corpus", type=type(T["dev_corpus"])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
distill_corpus, dev_corpus = resolve_dot_names(config, dot_names)
|
||||||
|
optimizer = T["optimizer"]
|
||||||
|
# Components that shouldn't be updated during training
|
||||||
|
frozen_components = T["frozen_components"]
|
||||||
|
# Sourced components that require resume_training
|
||||||
|
resume_components = [p for p in sourced if p not in frozen_components]
|
||||||
|
logger.info(f"Pipeline: {nlp.pipe_names}")
|
||||||
|
if resume_components:
|
||||||
|
with nlp.select_pipes(enable=resume_components):
|
||||||
|
logger.info(f"Resuming training for: {resume_components}")
|
||||||
|
nlp.resume_training(sgd=optimizer)
|
||||||
|
# Make sure that listeners are defined before initializing further
|
||||||
|
nlp._link_components()
|
||||||
|
|
||||||
|
# Get teacher labels to initialize student with.
|
||||||
|
student_to_teacher = D["student_to_teacher"]
|
||||||
|
teacher_pipes = dict(teacher.pipeline)
|
||||||
|
labels = {}
|
||||||
|
for name, pipe in nlp.pipeline:
|
||||||
|
# Copy teacher labels.
|
||||||
|
teacher_pipe_name = student_to_teacher[name] if name in student_to_teacher else name
|
||||||
|
teacher_pipe = teacher_pipes.get(teacher_pipe_name, None)
|
||||||
|
if (
|
||||||
|
teacher_pipe is not None
|
||||||
|
and getattr(teacher_pipe, "label_data", None) is not None
|
||||||
|
):
|
||||||
|
labels[name] = teacher_pipe.label_data # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
||||||
|
# Initialize on the dev corpus, since the distillation corpus does
|
||||||
|
# usually not have labels. Since we copy the labels from the teacher
|
||||||
|
# pipe, the dev data does not have to be exhaustive.
|
||||||
|
if T["max_epochs"] == -1:
|
||||||
|
sample_size = 100
|
||||||
|
logger.debug(
|
||||||
|
f"Due to streamed train corpus, using only first {sample_size} "
|
||||||
|
f"examples for initialization. If necessary, provide all labels "
|
||||||
|
f"in [initialize]. More info: https://spacy.io/api/cli#init_labels"
|
||||||
|
)
|
||||||
|
nlp.initialize(lambda: islice(dev_corpus(nlp), sample_size), sgd=optimizer)
|
||||||
|
else:
|
||||||
|
nlp.initialize(lambda: dev_corpus(nlp), sgd=optimizer, labels=labels)
|
||||||
|
logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
|
||||||
|
# Detect components with listeners that are not frozen consistently
|
||||||
|
for name, proc in nlp.pipeline:
|
||||||
|
for listener in getattr(
|
||||||
|
proc, "listening_components", []
|
||||||
|
): # e.g. tok2vec/transformer
|
||||||
|
# Don't warn about components not in the pipeline
|
||||||
|
if listener not in nlp.pipe_names:
|
||||||
|
continue
|
||||||
|
if listener in frozen_components and name not in frozen_components:
|
||||||
|
logger.warning(Warnings.W087.format(name=name, listener=listener))
|
||||||
|
# We always check this regardless, in case user freezes tok2vec
|
||||||
|
if listener not in frozen_components and name in frozen_components:
|
||||||
|
if name not in T["annotating_components"]:
|
||||||
|
logger.warning(Warnings.W086.format(name=name, listener=listener))
|
||||||
|
return nlp
|
||||||
|
|
||||||
|
|
||||||
def init_vocab(
|
def init_vocab(
|
||||||
nlp: "Language",
|
nlp: "Language",
|
||||||
*,
|
*,
|
||||||
|
|
|
@ -9,8 +9,9 @@ import sys
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from .example import Example
|
from .example import Example
|
||||||
from ..schemas import ConfigSchemaTraining
|
from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
|
from ..tokens.doc import Doc
|
||||||
from ..util import resolve_dot_names, registry, logger
|
from ..util import resolve_dot_names, registry, logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -21,6 +22,132 @@ DIR_MODEL_BEST = "model-best"
|
||||||
DIR_MODEL_LAST = "model-last"
|
DIR_MODEL_LAST = "model-last"
|
||||||
|
|
||||||
|
|
||||||
|
def distill(
|
||||||
|
teacher: "Language",
|
||||||
|
student: "Language",
|
||||||
|
output_path: Optional[Path] = None,
|
||||||
|
*,
|
||||||
|
use_gpu: int = -1,
|
||||||
|
stdout: IO = sys.stdout,
|
||||||
|
stderr: IO = sys.stderr,
|
||||||
|
) -> Tuple["Language", Optional[Path]]:
|
||||||
|
"""Distill a student pipeline from a teacher pipeline.
|
||||||
|
|
||||||
|
teacher (Language): The teacher pipeline to distill from.
|
||||||
|
student (Language): The student pipeline to distill into.
|
||||||
|
output_path (Optional[Path]): Optional output path to save the student
|
||||||
|
model to.
|
||||||
|
use_gpu (int): Whether to train on GPU. Make sure to call require_gpu
|
||||||
|
before calling this function.
|
||||||
|
stdout (file): A file-like object to write output messages. To disable
|
||||||
|
printing, set to io.StringIO.
|
||||||
|
stderr (file): A second file-like object to write output messages. To disable
|
||||||
|
printing, set to io.StringIO.
|
||||||
|
|
||||||
|
RETURNS (tuple): The final student nlp object and the path to the exported
|
||||||
|
student model.
|
||||||
|
"""
|
||||||
|
# We use no_print here so we can respect the stdout/stderr options.
|
||||||
|
msg = Printer(no_print=True)
|
||||||
|
# Create iterator, which yields out info after each optimization step.
|
||||||
|
config = student.config.interpolate()
|
||||||
|
if config["training"]["seed"] is not None:
|
||||||
|
fix_random_seed(config["training"]["seed"])
|
||||||
|
allocator = config["training"]["gpu_allocator"]
|
||||||
|
if use_gpu >= 0 and allocator:
|
||||||
|
set_gpu_allocator(allocator)
|
||||||
|
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
||||||
|
D = registry.resolve(config["distill"], schema=ConfigSchemaDistill)
|
||||||
|
dot_names = [D["corpus"], T["dev_corpus"]]
|
||||||
|
distill_corpus, dev_corpus = resolve_dot_names(config, dot_names)
|
||||||
|
optimizer = D["optimizer"]
|
||||||
|
score_weights = T["score_weights"]
|
||||||
|
batcher = D["batcher"]
|
||||||
|
train_logger = T["logger"]
|
||||||
|
before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
|
||||||
|
before_update = T["before_update"]
|
||||||
|
student_to_teacher = D["student_to_teacher"]
|
||||||
|
|
||||||
|
# Helper function to save checkpoints. This is a closure for convenience,
|
||||||
|
# to avoid passing in all the args all the time.
|
||||||
|
def save_checkpoint(is_best):
|
||||||
|
with student.use_params(optimizer.averages):
|
||||||
|
before_to_disk(student).to_disk(output_path / DIR_MODEL_LAST)
|
||||||
|
if is_best:
|
||||||
|
# Avoid saving twice (saving will be more expensive than
|
||||||
|
# the dir copy)
|
||||||
|
if (output_path / DIR_MODEL_BEST).exists():
|
||||||
|
shutil.rmtree(output_path / DIR_MODEL_BEST)
|
||||||
|
shutil.copytree(output_path / DIR_MODEL_LAST, output_path / DIR_MODEL_BEST)
|
||||||
|
|
||||||
|
# Components that shouldn't be updated during training
|
||||||
|
frozen_components = T["frozen_components"]
|
||||||
|
# Components that should set annotations on update
|
||||||
|
annotating_components = T["annotating_components"]
|
||||||
|
# Create iterator, which yields out info after each optimization step.
|
||||||
|
training_step_iterator = _distill_loop(
|
||||||
|
teacher,
|
||||||
|
student,
|
||||||
|
optimizer,
|
||||||
|
create_distill_batches(student, distill_corpus, batcher, D["max_epochs"]),
|
||||||
|
create_evaluation_callback(student, dev_corpus, score_weights),
|
||||||
|
dropout=D["dropout"],
|
||||||
|
accumulate_gradient=T["accumulate_gradient"],
|
||||||
|
max_steps=D["max_steps"],
|
||||||
|
eval_frequency=T["eval_frequency"],
|
||||||
|
exclude=frozen_components,
|
||||||
|
annotating_components=annotating_components,
|
||||||
|
before_update=before_update,
|
||||||
|
student_to_teacher=student_to_teacher,
|
||||||
|
)
|
||||||
|
clean_output_dir(output_path)
|
||||||
|
stdout.write(msg.info(f"Teacher pipeline: {teacher.pipe_names}") + "\n")
|
||||||
|
stdout.write(msg.info(f"Student pipeline: {student.pipe_names}") + "\n")
|
||||||
|
if frozen_components:
|
||||||
|
stdout.write(msg.info(f"Frozen components: {frozen_components}") + "\n")
|
||||||
|
if annotating_components:
|
||||||
|
stdout.write(
|
||||||
|
msg.info(f"Set annotations on update for: {annotating_components}") + "\n"
|
||||||
|
)
|
||||||
|
stdout.write(msg.info(f"Initial learn rate: {optimizer.learn_rate(step=0)}") + "\n")
|
||||||
|
with student.select_pipes(disable=frozen_components):
|
||||||
|
log_step, finalize_logger = train_logger(student, stdout, stderr)
|
||||||
|
try:
|
||||||
|
for batch, info, is_best_checkpoint in training_step_iterator:
|
||||||
|
if is_best_checkpoint is not None:
|
||||||
|
with student.select_pipes(disable=frozen_components):
|
||||||
|
update_meta(T, student, info)
|
||||||
|
if output_path is not None:
|
||||||
|
save_checkpoint(is_best_checkpoint)
|
||||||
|
info["output_path"] = str(output_path / DIR_MODEL_LAST)
|
||||||
|
log_step(info if is_best_checkpoint is not None else None)
|
||||||
|
except Exception as e:
|
||||||
|
if output_path is not None:
|
||||||
|
stdout.write(
|
||||||
|
msg.warn(
|
||||||
|
f"Aborting and saving the final best model. "
|
||||||
|
f"Encountered exception: {repr(e)}"
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
finalize_logger()
|
||||||
|
if output_path is not None:
|
||||||
|
save_checkpoint(False)
|
||||||
|
# This will only run if we did't hit an error
|
||||||
|
if optimizer.averages:
|
||||||
|
student.use_params(optimizer.averages)
|
||||||
|
if output_path is not None:
|
||||||
|
stdout.write(
|
||||||
|
msg.good("Saved pipeline to output directory", output_path / DIR_MODEL_LAST)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
return (student, output_path / DIR_MODEL_LAST)
|
||||||
|
else:
|
||||||
|
return (student, None)
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
nlp: "Language",
|
nlp: "Language",
|
||||||
output_path: Optional[Path] = None,
|
output_path: Optional[Path] = None,
|
||||||
|
@ -139,6 +266,126 @@ def train(
|
||||||
return (nlp, None)
|
return (nlp, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _distill_loop(
|
||||||
|
teacher: "Language",
|
||||||
|
student: "Language",
|
||||||
|
optimizer: Optimizer,
|
||||||
|
distill_data: Iterable[List[Example]],
|
||||||
|
evaluate: Callable[[], Tuple[float, Dict[str, float]]],
|
||||||
|
*,
|
||||||
|
dropout: float,
|
||||||
|
eval_frequency: int,
|
||||||
|
accumulate_gradient: int,
|
||||||
|
max_steps: int,
|
||||||
|
exclude: List[str],
|
||||||
|
annotating_components: List[str],
|
||||||
|
before_update: Optional[Callable[["Language", Dict[str, Any]], None]],
|
||||||
|
student_to_teacher: Dict[str, str],
|
||||||
|
):
|
||||||
|
"""Train until an evaluation stops improving. Works as a generator,
|
||||||
|
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
|
||||||
|
where info is a dict, and is_best_checkpoint is in [True, False, None] --
|
||||||
|
None indicating that the iteration was not evaluated as a checkpoint.
|
||||||
|
The evaluation is conducted by calling the evaluate callback.
|
||||||
|
|
||||||
|
Positional arguments:
|
||||||
|
teacher (Language): The teacher pipeline to distill from.
|
||||||
|
student (Language): The student pipeline to distill into.
|
||||||
|
optimizer: The optimizer callable.
|
||||||
|
distill_data (Iterable[Batch]): A generator of batches, with the distillation
|
||||||
|
data. Each batch should be a Sized[Tuple[Input, Annot]]. The distillation
|
||||||
|
data iterable needs to take care of iterating over the epochs and
|
||||||
|
shuffling.
|
||||||
|
evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
|
||||||
|
The callback should take no arguments and return a tuple
|
||||||
|
`(main_score, other_scores)`. The main_score should be a float where
|
||||||
|
higher is better. other_scores can be any object.
|
||||||
|
|
||||||
|
Every iteration, the function yields out a tuple with:
|
||||||
|
|
||||||
|
* batch: A list of Example objects.
|
||||||
|
* info: A dict with various information about the last update (see below).
|
||||||
|
* is_best_checkpoint: A value in None, False, True, indicating whether this
|
||||||
|
was the best evaluation so far. You should use this to save the model
|
||||||
|
checkpoints during training. If None, evaluation was not conducted on
|
||||||
|
that iteration. False means evaluation was conducted, but a previous
|
||||||
|
evaluation was better.
|
||||||
|
|
||||||
|
The info dict provides the following information:
|
||||||
|
|
||||||
|
epoch (int): How many passes over the data have been completed.
|
||||||
|
step (int): How many steps have been completed.
|
||||||
|
score (float): The main score from the last evaluation.
|
||||||
|
other_scores: : The other scores from the last evaluation.
|
||||||
|
losses: The accumulated losses throughout training.
|
||||||
|
checkpoints: A list of previous results, where each result is a
|
||||||
|
(score, step, epoch) tuple.
|
||||||
|
"""
|
||||||
|
if isinstance(dropout, float):
|
||||||
|
dropouts = constant(dropout)
|
||||||
|
else:
|
||||||
|
dropouts = dropout
|
||||||
|
results = []
|
||||||
|
losses: Dict[str, float] = {}
|
||||||
|
words_seen = 0
|
||||||
|
start_time = timer()
|
||||||
|
for step, (epoch, batch) in enumerate(distill_data):
|
||||||
|
if before_update:
|
||||||
|
before_update_args = {"step": step, "epoch": epoch}
|
||||||
|
before_update(student, before_update_args)
|
||||||
|
dropout = dropouts(optimizer.step)
|
||||||
|
for subbatch in subdivide_batch(batch, accumulate_gradient):
|
||||||
|
student.distill(
|
||||||
|
teacher,
|
||||||
|
subbatch,
|
||||||
|
drop=dropout,
|
||||||
|
losses=losses,
|
||||||
|
sgd=None,
|
||||||
|
exclude=exclude,
|
||||||
|
annotates=annotating_components,
|
||||||
|
student_to_teacher=student_to_teacher,
|
||||||
|
)
|
||||||
|
# TODO: refactor this so we don't have to run it separately in here
|
||||||
|
for name, proc in student.pipeline:
|
||||||
|
if (
|
||||||
|
name not in exclude
|
||||||
|
and hasattr(proc, "is_trainable")
|
||||||
|
and proc.is_trainable
|
||||||
|
and proc.model not in (True, False, None) # type: ignore[attr-defined]
|
||||||
|
):
|
||||||
|
proc.finish_update(optimizer) # type: ignore[attr-defined]
|
||||||
|
optimizer.step_schedules()
|
||||||
|
if not (step % eval_frequency):
|
||||||
|
if optimizer.averages:
|
||||||
|
with student.use_params(optimizer.averages):
|
||||||
|
score, other_scores = evaluate()
|
||||||
|
else:
|
||||||
|
score, other_scores = evaluate()
|
||||||
|
optimizer.last_score = score # type: ignore[assignment]
|
||||||
|
results.append((score, step))
|
||||||
|
is_best_checkpoint = score == max(results)[0]
|
||||||
|
else:
|
||||||
|
score, other_scores = (None, None)
|
||||||
|
is_best_checkpoint = None
|
||||||
|
words_seen += sum(len(eg) for eg in batch)
|
||||||
|
info = {
|
||||||
|
"epoch": epoch,
|
||||||
|
"step": step,
|
||||||
|
"score": score,
|
||||||
|
"other_scores": other_scores,
|
||||||
|
"losses": losses,
|
||||||
|
"checkpoints": results,
|
||||||
|
"seconds": int(timer() - start_time),
|
||||||
|
"words": words_seen,
|
||||||
|
}
|
||||||
|
yield batch, info, is_best_checkpoint
|
||||||
|
if is_best_checkpoint is not None:
|
||||||
|
losses = {}
|
||||||
|
# Stop if we've exhausted our max steps (if specified)
|
||||||
|
if max_steps and step >= max_steps:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
def train_while_improving(
|
def train_while_improving(
|
||||||
nlp: "Language",
|
nlp: "Language",
|
||||||
optimizer: Optimizer,
|
optimizer: Optimizer,
|
||||||
|
@ -264,7 +511,10 @@ def train_while_improving(
|
||||||
|
|
||||||
def subdivide_batch(batch, accumulate_gradient):
|
def subdivide_batch(batch, accumulate_gradient):
|
||||||
batch = list(batch)
|
batch = list(batch)
|
||||||
batch.sort(key=lambda eg: len(eg.predicted))
|
if isinstance(batch, Example):
|
||||||
|
batch.sort(key=lambda eg: len(eg.predicted))
|
||||||
|
elif isinstance(batch, Doc):
|
||||||
|
batch.sort(key=lambda doc: len(doc))
|
||||||
sub_len = len(batch) // accumulate_gradient
|
sub_len = len(batch) // accumulate_gradient
|
||||||
start = 0
|
start = 0
|
||||||
for i in range(accumulate_gradient):
|
for i in range(accumulate_gradient):
|
||||||
|
@ -309,6 +559,20 @@ def create_evaluation_callback(
|
||||||
return evaluate
|
return evaluate
|
||||||
|
|
||||||
|
|
||||||
|
def create_distill_batches(
|
||||||
|
nlp: "Language",
|
||||||
|
corpus: Callable[["Language"], Iterable[Example]],
|
||||||
|
batcher: Callable[[Iterable[Example]], Iterable[List[Example]]],
|
||||||
|
max_epochs: int,
|
||||||
|
):
|
||||||
|
epoch = 0
|
||||||
|
while max_epochs < 1 or epoch != max_epochs:
|
||||||
|
examples = corpus(nlp)
|
||||||
|
for batch in batcher(examples):
|
||||||
|
yield epoch, batch
|
||||||
|
epoch += 1
|
||||||
|
|
||||||
|
|
||||||
def create_train_batches(
|
def create_train_batches(
|
||||||
nlp: "Language",
|
nlp: "Language",
|
||||||
corpus: Callable[["Language"], Iterable[Example]],
|
corpus: Callable[["Language"], Iterable[Example]],
|
||||||
|
|
Loading…
Reference in New Issue
Block a user