mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-31 11:46:22 +03:00
minimal-changes
This commit is contained in:
parent
2fc73b42ae
commit
2c73623a6b
|
@@ -200,9 +200,13 @@ def train_cli(
|
|||
except ImportError:
|
||||
msg.fail("Need to install ray_spacy to use distributed training!", exits=1)
|
||||
distributed_setup_and_train(use_gpu, num_workers, strategy, train_args)
|
||||
|
||||
else:
|
||||
setup_and_train(use_gpu, train_args)
|
||||
if use_gpu >= 0:
|
||||
msg.info(f"Using GPU: {use_gpu}")
|
||||
util.use_gpu(use_gpu)
|
||||
else:
|
||||
msg.info("Using CPU")
|
||||
train(**train_args)
|
||||
|
||||
def train(
|
||||
config_path,
|
||||
|
@@ -213,7 +217,8 @@ def train(
|
|||
weights_data=None,
|
||||
omit_extra_lookups=False,
|
||||
disable_tqdm=False,
|
||||
remote_optimizer=None
|
||||
remote_optimizer=None,
|
||||
randomization_index=0
|
||||
):
|
||||
msg.info(f"Loading config from: {config_path}")
|
||||
# Read the config first without creating objects, to get to the original nlp_config
|
||||
|
@@ -327,7 +332,7 @@ def train(
|
|||
)
|
||||
tok2vec.from_bytes(weights_data)
|
||||
|
||||
train_batches = create_train_batches(nlp, corpus, training)
|
||||
train_batches = create_train_batches(nlp, corpus, training, randomization_index)
|
||||
evaluate = create_evaluation_callback(nlp, optimizer, corpus, training)
|
||||
|
||||
# Create iterator, which yields out info after each optimization step.
|
||||
|
@@ -382,7 +387,7 @@ def train(
|
|||
msg.good(f"Saved model to output directory {final_model_path}")
|
||||
|
||||
|
||||
def create_train_batches(nlp, corpus, cfg):
|
||||
def create_train_batches(nlp, corpus, cfg, randomization_index):
|
||||
epochs_todo = cfg.get("max_epochs", 0)
|
||||
while True:
|
||||
train_examples = list(
|
||||
|
@@ -397,8 +402,9 @@ def create_train_batches(nlp, corpus, cfg):
|
|||
)
|
||||
if len(train_examples) == 0:
|
||||
raise ValueError(Errors.E988)
|
||||
for _ in range(randomization_index):
|
||||
random.random()
|
||||
random.shuffle(train_examples)
|
||||
|
||||
batches = util.minibatch_by_words(
|
||||
train_examples,
|
||||
size=cfg["batch_size"],
|
||||
|
|
Loading…
Reference in New Issue
Block a user