mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-21 17:41:59 +03:00
Add distillation initialization and loop
This commit is contained in:
parent
5d0f48fe69
commit
9a72ea0b91
111
spacy/tests/training/test_loop.py
Normal file
111
spacy/tests/training/test_loop.py
Normal file
|
@ -0,0 +1,111 @@
|
|||
from typing import Callable, Iterable, Iterator
|
||||
import pytest
|
||||
from spacy import Language
|
||||
from spacy.training import Example
|
||||
from spacy.training.initialize import init_nlp_distill
|
||||
from spacy.training.loop import distill, train
|
||||
from spacy.util import load_model_from_config, registry
|
||||
from thinc.api import Config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_str():
|
||||
return """
|
||||
[nlp]
|
||||
lang = "en"
|
||||
pipeline = ["senter"]
|
||||
disabled = []
|
||||
before_creation = null
|
||||
after_creation = null
|
||||
after_pipeline_creation = null
|
||||
batch_size = 1000
|
||||
tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"}
|
||||
|
||||
[components]
|
||||
|
||||
[components.senter]
|
||||
factory = "senter"
|
||||
|
||||
[training]
|
||||
dev_corpus = "corpora.dev"
|
||||
train_corpus = "corpora.train"
|
||||
max_steps = 50
|
||||
seed = 1
|
||||
gpu_allocator = null
|
||||
|
||||
[distillation]
|
||||
corpus = "corpora.train"
|
||||
dropout = 0.1
|
||||
max_epochs = 0
|
||||
max_steps = 50
|
||||
student_to_teacher = {}
|
||||
|
||||
[distillation.batcher]
|
||||
@batchers = "spacy.batch_by_words.v1"
|
||||
size = 3000
|
||||
discard_oversize = false
|
||||
tolerance = 0.2
|
||||
|
||||
[distillation.optimizer]
|
||||
@optimizers = "Adam.v1"
|
||||
beta1 = 0.9
|
||||
beta2 = 0.999
|
||||
L2_is_weight_decay = true
|
||||
L2 = 0.01
|
||||
grad_clip = 1.0
|
||||
use_averages = true
|
||||
eps = 1e-8
|
||||
learn_rate = 1e-4
|
||||
|
||||
[corpora]
|
||||
|
||||
[corpora.dev]
|
||||
@readers = "sentence_corpus"
|
||||
|
||||
[corpora.train]
|
||||
@readers = "sentence_corpus"
|
||||
"""
|
||||
|
||||
|
||||
SENT_STARTS = [0] * 14
|
||||
SENT_STARTS[0] = 1
|
||||
SENT_STARTS[5] = 1
|
||||
SENT_STARTS[9] = 1
|
||||
|
||||
TRAIN_DATA = [
|
||||
(
|
||||
"I like green eggs. Eat blue ham. I like purple eggs.",
|
||||
{"sent_starts": SENT_STARTS},
|
||||
),
|
||||
(
|
||||
"She likes purple eggs. They hate ham. You like yellow eggs.",
|
||||
{"sent_starts": SENT_STARTS},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_distill_loop(config_str):
|
||||
@registry.readers("sentence_corpus")
|
||||
def create_sentence_corpus() -> Callable[[Language], Iterable[Example]]:
|
||||
return SentenceCorpus()
|
||||
|
||||
class SentenceCorpus:
|
||||
def __call__(self, nlp: Language) -> Iterator[Example]:
|
||||
for t in TRAIN_DATA:
|
||||
yield Example.from_dict(nlp.make_doc(t[0]), t[1])
|
||||
|
||||
orig_config = Config().from_str(config_str)
|
||||
teacher = load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||
teacher.initialize()
|
||||
train(teacher)
|
||||
|
||||
orig_config = Config().from_str(config_str)
|
||||
student = init_nlp_distill(orig_config, teacher)
|
||||
student.initialize()
|
||||
distill(teacher, student)
|
||||
|
||||
doc = student(TRAIN_DATA[0][0])
|
||||
assert doc.sents[0].text == "I like green eggs."
|
||||
assert doc.sents[1].text == "Eat blue ham."
|
||||
assert doc.sents[2].text == "I like purple eggs."
|
|
@ -15,7 +15,7 @@ from .pretrain import get_tok2vec_ref
|
|||
from ..lookups import Lookups
|
||||
from ..vectors import Vectors, Mode as VectorsMode
|
||||
from ..errors import Errors, Warnings
|
||||
from ..schemas import ConfigSchemaTraining
|
||||
from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
|
||||
from ..util import registry, load_model_from_config, resolve_dot_names, logger
|
||||
from ..util import load_model, ensure_path, get_sourced_components
|
||||
from ..util import OOV_RANK, DEFAULT_OOV_PROB
|
||||
|
@ -27,15 +27,8 @@ if TYPE_CHECKING:
|
|||
def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
||||
raw_config = config
|
||||
config = raw_config.interpolate()
|
||||
if "seed" not in config["training"]:
|
||||
raise ValueError(Errors.E1015.format(value="[training] seed"))
|
||||
if "gpu_allocator" not in config["training"]:
|
||||
raise ValueError(Errors.E1015.format(value="[training] gpu_allocator"))
|
||||
if config["training"]["seed"] is not None:
|
||||
fix_random_seed(config["training"]["seed"])
|
||||
allocator = config["training"]["gpu_allocator"]
|
||||
if use_gpu >= 0 and allocator:
|
||||
set_gpu_allocator(allocator)
|
||||
_set_seed_from_config(config)
|
||||
_set_gpu_allocator_from_config(config, use_gpu)
|
||||
# Use original config here before it's resolved to functions
|
||||
sourced = get_sourced_components(config)
|
||||
nlp = load_model_from_config(raw_config, auto_fill=True)
|
||||
|
@ -101,6 +94,110 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
|||
return nlp
|
||||
|
||||
|
||||
def _set_gpu_allocator_from_config(config, use_gpu):
|
||||
if "gpu_allocator" not in config["training"]:
|
||||
raise ValueError(Errors.E1015.format(value="[training] gpu_allocator"))
|
||||
allocator = config["training"]["gpu_allocator"]
|
||||
if use_gpu >= 0 and allocator:
|
||||
set_gpu_allocator(allocator)
|
||||
|
||||
|
||||
def _set_seed_from_config(config):
|
||||
if "seed" not in config["training"]:
|
||||
raise ValueError(Errors.E1015.format(value="[training] seed"))
|
||||
if config["training"]["seed"] is not None:
|
||||
fix_random_seed(config["training"]["seed"])
|
||||
|
||||
|
||||
def init_nlp_distill(
|
||||
config: Config, teacher: "Language", *, use_gpu: int = -1
|
||||
) -> "Language":
|
||||
raw_config = config
|
||||
config = raw_config.interpolate()
|
||||
_set_seed_from_config(config)
|
||||
_set_gpu_allocator_from_config(config, use_gpu)
|
||||
|
||||
# Use original config here before it's resolved to functions
|
||||
sourced = get_sourced_components(config)
|
||||
nlp = load_model_from_config(raw_config, auto_fill=True)
|
||||
logger.info("Set up nlp object from config")
|
||||
config = nlp.config.interpolate()
|
||||
# Resolve all training-relevant sections using the filled nlp config
|
||||
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
||||
D = registry.resolve(config["distill"], schema=ConfigSchemaDistill)
|
||||
dot_names = [D["corpus"], T["dev_corpus"]]
|
||||
if not isinstance(D["corpus"], str):
|
||||
raise ConfigValidationError(
|
||||
desc=Errors.E897.format(
|
||||
field="distill.corpus", type=type(D["corpus"])
|
||||
)
|
||||
)
|
||||
if not isinstance(T["dev_corpus"], str):
|
||||
raise ConfigValidationError(
|
||||
desc=Errors.E897.format(
|
||||
field="training.dev_corpus", type=type(T["dev_corpus"])
|
||||
)
|
||||
)
|
||||
distill_corpus, dev_corpus = resolve_dot_names(config, dot_names)
|
||||
optimizer = T["optimizer"]
|
||||
# Components that shouldn't be updated during training
|
||||
frozen_components = T["frozen_components"]
|
||||
# Sourced components that require resume_training
|
||||
resume_components = [p for p in sourced if p not in frozen_components]
|
||||
logger.info(f"Pipeline: {nlp.pipe_names}")
|
||||
if resume_components:
|
||||
with nlp.select_pipes(enable=resume_components):
|
||||
logger.info(f"Resuming training for: {resume_components}")
|
||||
nlp.resume_training(sgd=optimizer)
|
||||
# Make sure that listeners are defined before initializing further
|
||||
nlp._link_components()
|
||||
|
||||
# Get teacher labels to initialize student with.
|
||||
student_to_teacher = D["student_to_teacher"]
|
||||
teacher_pipes = dict(teacher.pipeline)
|
||||
labels = {}
|
||||
for name, pipe in nlp.pipeline:
|
||||
# Copy teacher labels.
|
||||
teacher_pipe_name = student_to_teacher[name] if name in student_to_teacher else name
|
||||
teacher_pipe = teacher_pipes.get(teacher_pipe_name, None)
|
||||
if (
|
||||
teacher_pipe is not None
|
||||
and getattr(teacher_pipe, "label_data", None) is not None
|
||||
):
|
||||
labels[name] = teacher_pipe.label_data # type: ignore[attr-defined]
|
||||
|
||||
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
||||
# Initialize on the dev corpus, since the distillation corpus does
|
||||
# usually not have labels. Since we copy the labels from the teacher
|
||||
# pipe, the dev data does not have to be exhaustive.
|
||||
if T["max_epochs"] == -1:
|
||||
sample_size = 100
|
||||
logger.debug(
|
||||
f"Due to streamed train corpus, using only first {sample_size} "
|
||||
f"examples for initialization. If necessary, provide all labels "
|
||||
f"in [initialize]. More info: https://spacy.io/api/cli#init_labels"
|
||||
)
|
||||
nlp.initialize(lambda: islice(dev_corpus(nlp), sample_size), sgd=optimizer)
|
||||
else:
|
||||
nlp.initialize(lambda: dev_corpus(nlp), sgd=optimizer, labels=labels)
|
||||
logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
|
||||
# Detect components with listeners that are not frozen consistently
|
||||
for name, proc in nlp.pipeline:
|
||||
for listener in getattr(
|
||||
proc, "listening_components", []
|
||||
): # e.g. tok2vec/transformer
|
||||
# Don't warn about components not in the pipeline
|
||||
if listener not in nlp.pipe_names:
|
||||
continue
|
||||
if listener in frozen_components and name not in frozen_components:
|
||||
logger.warning(Warnings.W087.format(name=name, listener=listener))
|
||||
# We always check this regardless, in case user freezes tok2vec
|
||||
if listener not in frozen_components and name in frozen_components:
|
||||
if name not in T["annotating_components"]:
|
||||
logger.warning(Warnings.W086.format(name=name, listener=listener))
|
||||
return nlp
|
||||
|
||||
|
||||
def init_vocab(
|
||||
nlp: "Language",
|
||||
*,
|
||||
|
|
|
@ -9,8 +9,9 @@ import sys
|
|||
import shutil
|
||||
|
||||
from .example import Example
|
||||
from ..schemas import ConfigSchemaTraining
|
||||
from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
|
||||
from ..errors import Errors
|
||||
from ..tokens.doc import Doc
|
||||
from ..util import resolve_dot_names, registry, logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -21,6 +22,132 @@ DIR_MODEL_BEST = "model-best"
|
|||
DIR_MODEL_LAST = "model-last"
|
||||
|
||||
|
||||
def distill(
|
||||
teacher: "Language",
|
||||
student: "Language",
|
||||
output_path: Optional[Path] = None,
|
||||
*,
|
||||
use_gpu: int = -1,
|
||||
stdout: IO = sys.stdout,
|
||||
stderr: IO = sys.stderr,
|
||||
) -> Tuple["Language", Optional[Path]]:
|
||||
"""Distill a student pipeline from a teacher pipeline.
|
||||
|
||||
teacher (Language): The teacher pipeline to distill from.
|
||||
student (Language): The student pipeline to distill into.
|
||||
output_path (Optional[Path]): Optional output path to save the student
|
||||
model to.
|
||||
use_gpu (int): Whether to train on GPU. Make sure to call require_gpu
|
||||
before calling this function.
|
||||
stdout (file): A file-like object to write output messages. To disable
|
||||
printing, set to io.StringIO.
|
||||
stderr (file): A second file-like object to write output messages. To disable
|
||||
printing, set to io.StringIO.
|
||||
|
||||
RETURNS (tuple): The final student nlp object and the path to the exported
|
||||
student model.
|
||||
"""
|
||||
# We use no_print here so we can respect the stdout/stderr options.
|
||||
msg = Printer(no_print=True)
|
||||
# Create iterator, which yields out info after each optimization step.
|
||||
config = student.config.interpolate()
|
||||
if config["training"]["seed"] is not None:
|
||||
fix_random_seed(config["training"]["seed"])
|
||||
allocator = config["training"]["gpu_allocator"]
|
||||
if use_gpu >= 0 and allocator:
|
||||
set_gpu_allocator(allocator)
|
||||
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
||||
D = registry.resolve(config["distill"], schema=ConfigSchemaDistill)
|
||||
dot_names = [D["corpus"], T["dev_corpus"]]
|
||||
distill_corpus, dev_corpus = resolve_dot_names(config, dot_names)
|
||||
optimizer = D["optimizer"]
|
||||
score_weights = T["score_weights"]
|
||||
batcher = D["batcher"]
|
||||
train_logger = T["logger"]
|
||||
before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
|
||||
before_update = T["before_update"]
|
||||
student_to_teacher = D["student_to_teacher"]
|
||||
|
||||
# Helper function to save checkpoints. This is a closure for convenience,
|
||||
# to avoid passing in all the args all the time.
|
||||
def save_checkpoint(is_best):
|
||||
with student.use_params(optimizer.averages):
|
||||
before_to_disk(student).to_disk(output_path / DIR_MODEL_LAST)
|
||||
if is_best:
|
||||
# Avoid saving twice (saving will be more expensive than
|
||||
# the dir copy)
|
||||
if (output_path / DIR_MODEL_BEST).exists():
|
||||
shutil.rmtree(output_path / DIR_MODEL_BEST)
|
||||
shutil.copytree(output_path / DIR_MODEL_LAST, output_path / DIR_MODEL_BEST)
|
||||
|
||||
# Components that shouldn't be updated during training
|
||||
frozen_components = T["frozen_components"]
|
||||
# Components that should set annotations on update
|
||||
annotating_components = T["annotating_components"]
|
||||
# Create iterator, which yields out info after each optimization step.
|
||||
training_step_iterator = _distill_loop(
|
||||
teacher,
|
||||
student,
|
||||
optimizer,
|
||||
create_distill_batches(student, distill_corpus, batcher, D["max_epochs"]),
|
||||
create_evaluation_callback(student, dev_corpus, score_weights),
|
||||
dropout=D["dropout"],
|
||||
accumulate_gradient=T["accumulate_gradient"],
|
||||
max_steps=D["max_steps"],
|
||||
eval_frequency=T["eval_frequency"],
|
||||
exclude=frozen_components,
|
||||
annotating_components=annotating_components,
|
||||
before_update=before_update,
|
||||
student_to_teacher=student_to_teacher,
|
||||
)
|
||||
clean_output_dir(output_path)
|
||||
stdout.write(msg.info(f"Teacher pipeline: {teacher.pipe_names}") + "\n")
|
||||
stdout.write(msg.info(f"Student pipeline: {student.pipe_names}") + "\n")
|
||||
if frozen_components:
|
||||
stdout.write(msg.info(f"Frozen components: {frozen_components}") + "\n")
|
||||
if annotating_components:
|
||||
stdout.write(
|
||||
msg.info(f"Set annotations on update for: {annotating_components}") + "\n"
|
||||
)
|
||||
stdout.write(msg.info(f"Initial learn rate: {optimizer.learn_rate(step=0)}") + "\n")
|
||||
with student.select_pipes(disable=frozen_components):
|
||||
log_step, finalize_logger = train_logger(student, stdout, stderr)
|
||||
try:
|
||||
for batch, info, is_best_checkpoint in training_step_iterator:
|
||||
if is_best_checkpoint is not None:
|
||||
with student.select_pipes(disable=frozen_components):
|
||||
update_meta(T, student, info)
|
||||
if output_path is not None:
|
||||
save_checkpoint(is_best_checkpoint)
|
||||
info["output_path"] = str(output_path / DIR_MODEL_LAST)
|
||||
log_step(info if is_best_checkpoint is not None else None)
|
||||
except Exception as e:
|
||||
if output_path is not None:
|
||||
stdout.write(
|
||||
msg.warn(
|
||||
f"Aborting and saving the final best model. "
|
||||
f"Encountered exception: {repr(e)}"
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
finalize_logger()
|
||||
if output_path is not None:
|
||||
save_checkpoint(False)
|
||||
# This will only run if we did't hit an error
|
||||
if optimizer.averages:
|
||||
student.use_params(optimizer.averages)
|
||||
if output_path is not None:
|
||||
stdout.write(
|
||||
msg.good("Saved pipeline to output directory", output_path / DIR_MODEL_LAST)
|
||||
+ "\n"
|
||||
)
|
||||
return (student, output_path / DIR_MODEL_LAST)
|
||||
else:
|
||||
return (student, None)
|
||||
|
||||
|
||||
def train(
|
||||
nlp: "Language",
|
||||
output_path: Optional[Path] = None,
|
||||
|
@ -139,6 +266,126 @@ def train(
|
|||
return (nlp, None)
|
||||
|
||||
|
||||
def _distill_loop(
|
||||
teacher: "Language",
|
||||
student: "Language",
|
||||
optimizer: Optimizer,
|
||||
distill_data: Iterable[List[Example]],
|
||||
evaluate: Callable[[], Tuple[float, Dict[str, float]]],
|
||||
*,
|
||||
dropout: float,
|
||||
eval_frequency: int,
|
||||
accumulate_gradient: int,
|
||||
max_steps: int,
|
||||
exclude: List[str],
|
||||
annotating_components: List[str],
|
||||
before_update: Optional[Callable[["Language", Dict[str, Any]], None]],
|
||||
student_to_teacher: Dict[str, str],
|
||||
):
|
||||
"""Train until an evaluation stops improving. Works as a generator,
|
||||
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
|
||||
where info is a dict, and is_best_checkpoint is in [True, False, None] --
|
||||
None indicating that the iteration was not evaluated as a checkpoint.
|
||||
The evaluation is conducted by calling the evaluate callback.
|
||||
|
||||
Positional arguments:
|
||||
teacher (Language): The teacher pipeline to distill from.
|
||||
student (Language): The student pipeline to distill into.
|
||||
optimizer: The optimizer callable.
|
||||
distill_data (Iterable[Batch]): A generator of batches, with the distillation
|
||||
data. Each batch should be a Sized[Tuple[Input, Annot]]. The distillation
|
||||
data iterable needs to take care of iterating over the epochs and
|
||||
shuffling.
|
||||
evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
|
||||
The callback should take no arguments and return a tuple
|
||||
`(main_score, other_scores)`. The main_score should be a float where
|
||||
higher is better. other_scores can be any object.
|
||||
|
||||
Every iteration, the function yields out a tuple with:
|
||||
|
||||
* batch: A list of Example objects.
|
||||
* info: A dict with various information about the last update (see below).
|
||||
* is_best_checkpoint: A value in None, False, True, indicating whether this
|
||||
was the best evaluation so far. You should use this to save the model
|
||||
checkpoints during training. If None, evaluation was not conducted on
|
||||
that iteration. False means evaluation was conducted, but a previous
|
||||
evaluation was better.
|
||||
|
||||
The info dict provides the following information:
|
||||
|
||||
epoch (int): How many passes over the data have been completed.
|
||||
step (int): How many steps have been completed.
|
||||
score (float): The main score from the last evaluation.
|
||||
other_scores: : The other scores from the last evaluation.
|
||||
losses: The accumulated losses throughout training.
|
||||
checkpoints: A list of previous results, where each result is a
|
||||
(score, step, epoch) tuple.
|
||||
"""
|
||||
if isinstance(dropout, float):
|
||||
dropouts = constant(dropout)
|
||||
else:
|
||||
dropouts = dropout
|
||||
results = []
|
||||
losses: Dict[str, float] = {}
|
||||
words_seen = 0
|
||||
start_time = timer()
|
||||
for step, (epoch, batch) in enumerate(distill_data):
|
||||
if before_update:
|
||||
before_update_args = {"step": step, "epoch": epoch}
|
||||
before_update(student, before_update_args)
|
||||
dropout = dropouts(optimizer.step)
|
||||
for subbatch in subdivide_batch(batch, accumulate_gradient):
|
||||
student.distill(
|
||||
teacher,
|
||||
subbatch,
|
||||
drop=dropout,
|
||||
losses=losses,
|
||||
sgd=None,
|
||||
exclude=exclude,
|
||||
annotates=annotating_components,
|
||||
student_to_teacher=student_to_teacher,
|
||||
)
|
||||
# TODO: refactor this so we don't have to run it separately in here
|
||||
for name, proc in student.pipeline:
|
||||
if (
|
||||
name not in exclude
|
||||
and hasattr(proc, "is_trainable")
|
||||
and proc.is_trainable
|
||||
and proc.model not in (True, False, None) # type: ignore[attr-defined]
|
||||
):
|
||||
proc.finish_update(optimizer) # type: ignore[attr-defined]
|
||||
optimizer.step_schedules()
|
||||
if not (step % eval_frequency):
|
||||
if optimizer.averages:
|
||||
with student.use_params(optimizer.averages):
|
||||
score, other_scores = evaluate()
|
||||
else:
|
||||
score, other_scores = evaluate()
|
||||
optimizer.last_score = score # type: ignore[assignment]
|
||||
results.append((score, step))
|
||||
is_best_checkpoint = score == max(results)[0]
|
||||
else:
|
||||
score, other_scores = (None, None)
|
||||
is_best_checkpoint = None
|
||||
words_seen += sum(len(eg) for eg in batch)
|
||||
info = {
|
||||
"epoch": epoch,
|
||||
"step": step,
|
||||
"score": score,
|
||||
"other_scores": other_scores,
|
||||
"losses": losses,
|
||||
"checkpoints": results,
|
||||
"seconds": int(timer() - start_time),
|
||||
"words": words_seen,
|
||||
}
|
||||
yield batch, info, is_best_checkpoint
|
||||
if is_best_checkpoint is not None:
|
||||
losses = {}
|
||||
# Stop if we've exhausted our max steps (if specified)
|
||||
if max_steps and step >= max_steps:
|
||||
break
|
||||
|
||||
|
||||
def train_while_improving(
|
||||
nlp: "Language",
|
||||
optimizer: Optimizer,
|
||||
|
@ -264,7 +511,10 @@ def train_while_improving(
|
|||
|
||||
def subdivide_batch(batch, accumulate_gradient):
|
||||
batch = list(batch)
|
||||
batch.sort(key=lambda eg: len(eg.predicted))
|
||||
if isinstance(batch, Example):
|
||||
batch.sort(key=lambda eg: len(eg.predicted))
|
||||
elif isinstance(batch, Doc):
|
||||
batch.sort(key=lambda doc: len(doc))
|
||||
sub_len = len(batch) // accumulate_gradient
|
||||
start = 0
|
||||
for i in range(accumulate_gradient):
|
||||
|
@ -309,6 +559,20 @@ def create_evaluation_callback(
|
|||
return evaluate
|
||||
|
||||
|
||||
def create_distill_batches(
|
||||
nlp: "Language",
|
||||
corpus: Callable[["Language"], Iterable[Example]],
|
||||
batcher: Callable[[Iterable[Example]], Iterable[List[Example]]],
|
||||
max_epochs: int,
|
||||
):
|
||||
epoch = 0
|
||||
while max_epochs < 1 or epoch != max_epochs:
|
||||
examples = corpus(nlp)
|
||||
for batch in batcher(examples):
|
||||
yield epoch, batch
|
||||
epoch += 1
|
||||
|
||||
|
||||
def create_train_batches(
|
||||
nlp: "Language",
|
||||
corpus: Callable[["Language"], Iterable[Example]],
|
||||
|
|
Loading…
Reference in New Issue
Block a user