mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-07 15:10:34 +03:00
minimal-changes
This commit is contained in:
parent
2fc73b42ae
commit
2c73623a6b
|
@ -200,9 +200,13 @@ def train_cli(
|
||||||
except ImportError:
|
except ImportError:
|
||||||
msg.fail("Need to install ray_spacy to use distributed training!", exits=1)
|
msg.fail("Need to install ray_spacy to use distributed training!", exits=1)
|
||||||
distributed_setup_and_train(use_gpu, num_workers, strategy, train_args)
|
distributed_setup_and_train(use_gpu, num_workers, strategy, train_args)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
setup_and_train(use_gpu, train_args)
|
if use_gpu >= 0:
|
||||||
|
msg.info(f"Using GPU: {use_gpu}")
|
||||||
|
util.use_gpu(use_gpu)
|
||||||
|
else:
|
||||||
|
msg.info("Using CPU")
|
||||||
|
train(**train_args)
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
config_path,
|
config_path,
|
||||||
|
@ -213,7 +217,8 @@ def train(
|
||||||
weights_data=None,
|
weights_data=None,
|
||||||
omit_extra_lookups=False,
|
omit_extra_lookups=False,
|
||||||
disable_tqdm=False,
|
disable_tqdm=False,
|
||||||
remote_optimizer=None
|
remote_optimizer=None,
|
||||||
|
randomization_index=0
|
||||||
):
|
):
|
||||||
msg.info(f"Loading config from: {config_path}")
|
msg.info(f"Loading config from: {config_path}")
|
||||||
# Read the config first without creating objects, to get to the original nlp_config
|
# Read the config first without creating objects, to get to the original nlp_config
|
||||||
|
@ -327,7 +332,7 @@ def train(
|
||||||
)
|
)
|
||||||
tok2vec.from_bytes(weights_data)
|
tok2vec.from_bytes(weights_data)
|
||||||
|
|
||||||
train_batches = create_train_batches(nlp, corpus, training)
|
train_batches = create_train_batches(nlp, corpus, training, randomization_index)
|
||||||
evaluate = create_evaluation_callback(nlp, optimizer, corpus, training)
|
evaluate = create_evaluation_callback(nlp, optimizer, corpus, training)
|
||||||
|
|
||||||
# Create iterator, which yields out info after each optimization step.
|
# Create iterator, which yields out info after each optimization step.
|
||||||
|
@ -382,7 +387,7 @@ def train(
|
||||||
msg.good(f"Saved model to output directory {final_model_path}")
|
msg.good(f"Saved model to output directory {final_model_path}")
|
||||||
|
|
||||||
|
|
||||||
def create_train_batches(nlp, corpus, cfg):
|
def create_train_batches(nlp, corpus, cfg, randomization_index):
|
||||||
epochs_todo = cfg.get("max_epochs", 0)
|
epochs_todo = cfg.get("max_epochs", 0)
|
||||||
while True:
|
while True:
|
||||||
train_examples = list(
|
train_examples = list(
|
||||||
|
@ -397,8 +402,9 @@ def create_train_batches(nlp, corpus, cfg):
|
||||||
)
|
)
|
||||||
if len(train_examples) == 0:
|
if len(train_examples) == 0:
|
||||||
raise ValueError(Errors.E988)
|
raise ValueError(Errors.E988)
|
||||||
|
for _ in range(randomization_index):
|
||||||
|
random.random()
|
||||||
random.shuffle(train_examples)
|
random.shuffle(train_examples)
|
||||||
|
|
||||||
batches = util.minibatch_by_words(
|
batches = util.minibatch_by_words(
|
||||||
train_examples,
|
train_examples,
|
||||||
size=cfg["batch_size"],
|
size=cfg["batch_size"],
|
||||||
|
|
Loading…
Reference in New Issue
Block a user