Update train.py

This commit is contained in:
Matthew Honnibal 2020-06-20 21:49:06 +02:00
parent 49145b9ec1
commit 17efd6bfec

View File

@@ -210,7 +210,8 @@ def train(
nlp.resume_training()
else:
msg.info(f"Initializing the nlp pipeline: {nlp.pipe_names}")
nlp.begin_training(lambda: corpus.train_dataset(nlp))
train_examples = list(corpus.train_dataset(nlp, shuffle=False))
nlp.begin_training(lambda: train_examples)
# Update tag map with provided mapping
nlp.vocab.morphology.tag_map.update(tag_map)
@@ -280,11 +281,14 @@ def train(
eg.reference = None
eg.predicted = None
except Exception as e:
msg.warn(
f"Aborting and saving the final best model. "
f"Encountered exception: {str(e)}",
exits=1,
)
if output_path is not None:
msg.warn(
f"Aborting and saving the final best model. "
f"Encountered exception: {str(e)}",
exits=1,
)
else:
raise e
finally:
if output_path is not None:
final_model_path = output_path / "model-final"
@@ -300,7 +304,6 @@ def create_train_batches(nlp, corpus, cfg):
epochs_todo = cfg.get("max_epochs", 0)
while True:
train_examples = list(corpus.train_dataset(nlp))
if len(train_examples) == 0:
raise ValueError(Errors.E988)
random.shuffle(train_examples)