mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-14 10:12:22 +03:00
chane naming and add finally block
This commit is contained in:
parent
8d50a8009c
commit
853ef78198
|
@ -60,11 +60,11 @@ def pretrain(
|
||||||
row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")}
|
row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")}
|
||||||
msg.row(("#", "# Words", "Total Loss", "Loss", "w/s"), **row_settings)
|
msg.row(("#", "# Words", "Total Loss", "Loss", "w/s"), **row_settings)
|
||||||
|
|
||||||
def _save_model(epoch, is_temp=False, is_latest=False):
|
def _save_model(epoch, is_temp=False, is_last=False):
|
||||||
is_temp_str = ".temp" if is_temp else ""
|
is_temp_str = ".temp" if is_temp else ""
|
||||||
with model.use_params(optimizer.averages):
|
with model.use_params(optimizer.averages):
|
||||||
if is_latest:
|
if is_last:
|
||||||
save_path = output_dir / f"model_latest.bin"
|
save_path = output_dir / f"model_last.bin"
|
||||||
else:
|
else:
|
||||||
save_path = output_dir / f"model{epoch}{is_temp_str}.bin"
|
save_path = output_dir / f"model{epoch}{is_temp_str}.bin"
|
||||||
with (save_path).open("wb") as file_:
|
with (save_path).open("wb") as file_:
|
||||||
|
@ -81,6 +81,7 @@ def pretrain(
|
||||||
# TODO: I think we probably want this to look more like the
|
# TODO: I think we probably want this to look more like the
|
||||||
# 'create_train_batches' function?
|
# 'create_train_batches' function?
|
||||||
for epoch in range(epoch_resume, P["max_epochs"]):
|
for epoch in range(epoch_resume, P["max_epochs"]):
|
||||||
|
try:
|
||||||
for batch_id, batch in enumerate(batcher(corpus(nlp))):
|
for batch_id, batch in enumerate(batcher(corpus(nlp))):
|
||||||
docs = ensure_docs(batch)
|
docs = ensure_docs(batch)
|
||||||
loss = make_update(model, docs, optimizer, objective)
|
loss = make_update(model, docs, optimizer, objective)
|
||||||
|
@ -90,14 +91,14 @@ def pretrain(
|
||||||
if P["n_save_every"] and (batch_id % P["n_save_every"] == 0):
|
if P["n_save_every"] and (batch_id % P["n_save_every"] == 0):
|
||||||
_save_model(epoch, is_temp=True)
|
_save_model(epoch, is_temp=True)
|
||||||
|
|
||||||
if epoch + 1 == P["max_epochs"]:
|
if P["n_save_epoch"]:
|
||||||
_save_model(epoch, is_latest=True)
|
|
||||||
elif P["n_save_epoch"]:
|
|
||||||
if epoch % P["n_save_epoch"] == 0 or epoch == P["max_epochs"] - 1:
|
if epoch % P["n_save_epoch"] == 0 or epoch == P["max_epochs"] - 1:
|
||||||
_save_model(epoch)
|
_save_model(epoch)
|
||||||
else:
|
else:
|
||||||
_save_model(epoch)
|
_save_model(epoch)
|
||||||
tracker.epoch_loss = 0.0
|
tracker.epoch_loss = 0.0
|
||||||
|
finally:
|
||||||
|
_save_model(epoch, is_last=True)
|
||||||
|
|
||||||
|
|
||||||
def ensure_docs(examples_or_docs: Iterable[Union[Doc, Example]]) -> List[Doc]:
|
def ensure_docs(examples_or_docs: Iterable[Union[Doc, Example]]) -> List[Doc]:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user