minimal-changes

Richard Liaw 2020-06-22 18:39:41 -07:00
parent 2fc73b42ae
commit 2c73623a6b

@@ -200,9 +200,13 @@ def train_cli(
         except ImportError:
             msg.fail("Need to install ray_spacy to use distributed training!", exits=1)
         distributed_setup_and_train(use_gpu, num_workers, strategy, train_args)
     else:
-        setup_and_train(use_gpu, train_args)
+        if use_gpu >= 0:
+            msg.info(f"Using GPU: {use_gpu}")
+            util.use_gpu(use_gpu)
+        else:
+            msg.info("Using CPU")
+        train(**train_args)
 
 
 def train(
     config_path,
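
In the local (non-distributed) branch, the CLI now picks the device itself and calls train() directly instead of going through setup_and_train(). A minimal sketch of the resulting dispatch, assuming the hunk sits under an `if num_workers and num_workers > 1:` check, and using print() and a caller-supplied local_train as hypothetical stand-ins for spaCy's msg, util.use_gpu, and train():

def run_training(use_gpu, num_workers, strategy, train_args, local_train):
    # local_train is a hypothetical stand-in for spaCy's train(**train_args)
    if num_workers and num_workers > 1:
        try:
            from ray_spacy import distributed_setup_and_train
        except ImportError:
            raise SystemExit("Need to install ray_spacy to use distributed training!")
        distributed_setup_and_train(use_gpu, num_workers, strategy, train_args)
    else:
        if use_gpu >= 0:
            print(f"Using GPU: {use_gpu}")  # util.use_gpu(use_gpu) activates the device in the real CLI
        else:
            print("Using CPU")
        local_train(**train_args)
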
@@ -213,7 +217,8 @@ def train(
     weights_data=None,
     omit_extra_lookups=False,
     disable_tqdm=False,
-    remote_optimizer=None
+    remote_optimizer=None,
+    randomization_index=0
 ):
     msg.info(f"Loading config from: {config_path}")
     # Read the config first without creating objects, to get to the original nlp_config
@@ -327,7 +332,7 @@ def train(
         )
         tok2vec.from_bytes(weights_data)
 
-    train_batches = create_train_batches(nlp, corpus, training)
+    train_batches = create_train_batches(nlp, corpus, training, randomization_index)
     evaluate = create_evaluation_callback(nlp, optimizer, corpus, training)
 
     # Create iterator, which yields out info after each optimization step.
@@ -382,7 +387,7 @@ def train(
             msg.good(f"Saved model to output directory {final_model_path}")
 
 
-def create_train_batches(nlp, corpus, cfg):
+def create_train_batches(nlp, corpus, cfg, randomization_index):
     epochs_todo = cfg.get("max_epochs", 0)
     while True:
         train_examples = list(
@@ -397,8 +402,9 @@ def create_train_batches(nlp, corpus, cfg):
         )
         if len(train_examples) == 0:
             raise ValueError(Errors.E988)
+        for _ in range(randomization_index):
+            random.random()
         random.shuffle(train_examples)
         batches = util.minibatch_by_words(
             train_examples,
             size=cfg["batch_size"],
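
The two inserted lines consume `randomization_index` draws from Python's module-level random generator before shuffling. Presumably each worker passes a different index, so workers sharing the same global seed still shuffle the corpus into different but reproducible orders. A small self-contained sketch of that effect (the seed, example list, and function name are placeholders, not spaCy API):

import random

def shuffled_for_worker(examples, randomization_index, seed=0):
    random.seed(seed)                     # every worker starts from the same seed
    for _ in range(randomization_index):  # advance the RNG by the worker's offset
        random.random()
    examples = list(examples)
    random.shuffle(examples)              # deterministic per (seed, index) pair
    return examples

# Workers with different indices see the same data in different orders.
print(shuffled_for_worker(range(5), 0))
print(shuffled_for_worker(range(5), 1))
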