mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
WIP on resume
This commit is contained in:
parent
3d8388969e
commit
db3815aa24
|
@ -135,9 +135,14 @@ def train(
|
||||||
layer.from_bytes(weights_data)
|
layer.from_bytes(weights_data)
|
||||||
msg.info(f"Loaded pretrained weights into component '{tok2vec_component}'")
|
msg.info(f"Loaded pretrained weights into component '{tok2vec_component}'")
|
||||||
|
|
||||||
# Create iterator, which yields out info after each optimization step.
|
|
||||||
msg.info("Start training")
|
|
||||||
score_weights = T_cfg["score_weights"]
|
score_weights = T_cfg["score_weights"]
|
||||||
|
if resume_training and has_checkpoint(output_path):
|
||||||
|
nlp, optimizer, resumed_from = load_checkpoint(output_path, nlp, optimizer)
|
||||||
|
msg.info(f"Resuming training from step {nr_step}")
|
||||||
|
else:
|
||||||
|
msg.info("Start training")
|
||||||
|
resumed_from = None
|
||||||
|
# Create iterator, which yields out info after each optimization step.
|
||||||
training_step_iterator = train_while_improving(
|
training_step_iterator = train_while_improving(
|
||||||
nlp,
|
nlp,
|
||||||
optimizer,
|
optimizer,
|
||||||
|
@ -150,6 +155,7 @@ def train(
|
||||||
eval_frequency=T_cfg["eval_frequency"],
|
eval_frequency=T_cfg["eval_frequency"],
|
||||||
raw_text=None,
|
raw_text=None,
|
||||||
exclude=frozen_components,
|
exclude=frozen_components,
|
||||||
|
resumed_from=resumed_from
|
||||||
)
|
)
|
||||||
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
|
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
|
||||||
with nlp.select_pipes(disable=frozen_components):
|
with nlp.select_pipes(disable=frozen_components):
|
||||||
|
@ -161,6 +167,7 @@ def train(
|
||||||
for batch, info, is_best_checkpoint in training_step_iterator:
|
for batch, info, is_best_checkpoint in training_step_iterator:
|
||||||
progress.update(1)
|
progress.update(1)
|
||||||
if is_best_checkpoint is not None:
|
if is_best_checkpoint is not None:
|
||||||
|
save_checkpoint(output_path, nlp, optimizer, info)
|
||||||
progress.close()
|
progress.close()
|
||||||
print_row(info)
|
print_row(info)
|
||||||
if is_best_checkpoint and output_path is not None:
|
if is_best_checkpoint and output_path is not None:
|
||||||
|
@ -171,22 +178,10 @@ def train(
|
||||||
nlp.to_disk(output_path / "model-best")
|
nlp.to_disk(output_path / "model-best")
|
||||||
progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
|
progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
|
||||||
progress.set_description(f"Epoch {info['epoch']}")
|
progress.set_description(f"Epoch {info['epoch']}")
|
||||||
except Exception as e:
|
|
||||||
finalize_logger()
|
|
||||||
if output_path is not None:
|
|
||||||
# We don't want to swallow the traceback if we don't have a
|
|
||||||
# specific error.
|
|
||||||
msg.warn(
|
|
||||||
f"Aborting and saving the final best model. "
|
|
||||||
f"Encountered exception: {str(e)}"
|
|
||||||
)
|
|
||||||
nlp = before_to_disk(nlp)
|
|
||||||
nlp.to_disk(output_path / "model-final")
|
|
||||||
raise e
|
|
||||||
finally:
|
finally:
|
||||||
finalize_logger()
|
finalize_logger()
|
||||||
if output_path is not None:
|
if output_path is not None:
|
||||||
final_model_path = output_path / "model-final"
|
final_model_path = output_path / "model-last"
|
||||||
if optimizer.averages:
|
if optimizer.averages:
|
||||||
with nlp.use_params(optimizer.averages):
|
with nlp.use_params(optimizer.averages):
|
||||||
nlp.to_disk(final_model_path)
|
nlp.to_disk(final_model_path)
|
||||||
|
@ -263,6 +258,7 @@ def train_while_improving(
|
||||||
max_steps: int,
|
max_steps: int,
|
||||||
raw_text: List[Dict[str, str]],
|
raw_text: List[Dict[str, str]],
|
||||||
exclude: List[str],
|
exclude: List[str],
|
||||||
|
resumed_from: Optional[Dict]=None
|
||||||
):
|
):
|
||||||
"""Train until an evaluation stops improving. Works as a generator,
|
"""Train until an evaluation stops improving. Works as a generator,
|
||||||
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
|
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
|
||||||
|
@ -306,8 +302,17 @@ def train_while_improving(
|
||||||
dropouts = thinc.schedules.constant(dropout)
|
dropouts = thinc.schedules.constant(dropout)
|
||||||
else:
|
else:
|
||||||
dropouts = dropout
|
dropouts = dropout
|
||||||
|
if resumed_from:
|
||||||
|
results = resumed_from["results"]
|
||||||
|
losses = resumed_from["losses"]
|
||||||
|
step = resumed_from["step"]
|
||||||
|
prev_seconds = resumed_from["seconds"]
|
||||||
|
else:
|
||||||
results = []
|
results = []
|
||||||
losses = {}
|
losses = {}
|
||||||
|
step = 0
|
||||||
|
words_seen = 0
|
||||||
|
prev_seconds = 0
|
||||||
if raw_text:
|
if raw_text:
|
||||||
random.shuffle(raw_text)
|
random.shuffle(raw_text)
|
||||||
raw_examples = [
|
raw_examples = [
|
||||||
|
@ -315,9 +320,14 @@ def train_while_improving(
|
||||||
]
|
]
|
||||||
raw_batches = util.minibatch(raw_examples, size=8)
|
raw_batches = util.minibatch(raw_examples, size=8)
|
||||||
|
|
||||||
words_seen = 0
|
for _, (epoch, batch) in zip(range(step), train_data):
|
||||||
|
# If we're resuming, allow the generators to advance for the steps we
|
||||||
|
# did before. It's hard to otherwise restore the generator state.
|
||||||
|
dropout = next(dropouts)
|
||||||
|
optimizer.step_schedules()
|
||||||
|
|
||||||
start_time = timer()
|
start_time = timer()
|
||||||
for step, (epoch, batch) in enumerate(train_data):
|
for epoch, batch in train_data:
|
||||||
dropout = next(dropouts)
|
dropout = next(dropouts)
|
||||||
for subbatch in subdivide_batch(batch, accumulate_gradient):
|
for subbatch in subdivide_batch(batch, accumulate_gradient):
|
||||||
|
|
||||||
|
@ -338,7 +348,7 @@ def train_while_improving(
|
||||||
):
|
):
|
||||||
proc.model.finish_update(optimizer)
|
proc.model.finish_update(optimizer)
|
||||||
optimizer.step_schedules()
|
optimizer.step_schedules()
|
||||||
if not (step % eval_frequency):
|
if step % eval_frequency:
|
||||||
if optimizer.averages:
|
if optimizer.averages:
|
||||||
with nlp.use_params(optimizer.averages):
|
with nlp.use_params(optimizer.averages):
|
||||||
score, other_scores = evaluate()
|
score, other_scores = evaluate()
|
||||||
|
@ -346,9 +356,6 @@ def train_while_improving(
|
||||||
score, other_scores = evaluate()
|
score, other_scores = evaluate()
|
||||||
results.append((score, step))
|
results.append((score, step))
|
||||||
is_best_checkpoint = score == max(results)[0]
|
is_best_checkpoint = score == max(results)[0]
|
||||||
else:
|
|
||||||
score, other_scores = (None, None)
|
|
||||||
is_best_checkpoint = None
|
|
||||||
words_seen += sum(len(eg) for eg in batch)
|
words_seen += sum(len(eg) for eg in batch)
|
||||||
info = {
|
info = {
|
||||||
"epoch": epoch,
|
"epoch": epoch,
|
||||||
|
@ -357,10 +364,13 @@ def train_while_improving(
|
||||||
"other_scores": other_scores,
|
"other_scores": other_scores,
|
||||||
"losses": losses,
|
"losses": losses,
|
||||||
"checkpoints": results,
|
"checkpoints": results,
|
||||||
"seconds": int(timer() - start_time),
|
"seconds": int(timer() - start_time) + prev_seconds,
|
||||||
"words": words_seen,
|
"words": words_seen,
|
||||||
}
|
}
|
||||||
yield batch, info, is_best_checkpoint
|
yield batch, info, is_best_checkpoint
|
||||||
|
else:
|
||||||
|
score, other_scores = (None, None)
|
||||||
|
is_best_checkpoint = None
|
||||||
if is_best_checkpoint is not None:
|
if is_best_checkpoint is not None:
|
||||||
losses = {}
|
losses = {}
|
||||||
# Stop if no improvement in `patience` updates (if specified)
|
# Stop if no improvement in `patience` updates (if specified)
|
||||||
|
@ -370,6 +380,7 @@ def train_while_improving(
|
||||||
# Stop if we've exhausted our max steps (if specified)
|
# Stop if we've exhausted our max steps (if specified)
|
||||||
if max_steps and step >= max_steps:
|
if max_steps and step >= max_steps:
|
||||||
break
|
break
|
||||||
|
step += 1
|
||||||
|
|
||||||
|
|
||||||
def subdivide_batch(batch, accumulate_gradient):
|
def subdivide_batch(batch, accumulate_gradient):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user