Mirror of https://github.com/explosion/spaCy.git (synced 2025-02-15 19:10:34 +03:00)
Changes to train for parallel training. Temporary -- don't merge.

commit 4c5d6b13c8 (parent 874b34e5d4)
@@ -171,6 +171,7 @@ def train_cli(
             }
         )
     else:
         msg.info(f"Loading config from: {config_path}")
         nlp, config = load_nlp_and_config(config_path)
+        corpus = Corpus(train_path, dev_path, limit=config["training"]["limit"])
@@ -203,6 +204,8 @@ def load_nlp_and_config(config_path):
     nlp_config = config["nlp"]
     config = util.load_config(config_path, create_objects=True)
     nlp = util.load_model_from_config(nlp_config)
+    # TODO: This is hacky, but temporary convenience...
+    config["_nlp_config"] = nlp_config
     return nlp, config


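The two added lines stash the raw [nlp] section on the config object so that train() can read it back later as config["_nlp_config"] without re-opening the file. A condensed sketch of the whole helper, assuming the first load passes create_objects=False (only the create_objects=True call is visible in the hunk):

    from spacy import util

    def load_nlp_and_config(config_path):
        # First pass: plain dict, keeping the original [nlp] section intact.
        # (create_objects=False here is an assumption; only the True call is shown above.)
        config = util.load_config(config_path, create_objects=False)
        nlp_config = config["nlp"]
        # Second pass: resolve registered functions into real objects.
        config = util.load_config(config_path, create_objects=True)
        nlp = util.load_model_from_config(nlp_config)
        # Stash the raw section for train() to recover via config["_nlp_config"].
        config["_nlp_config"] = nlp_config
        return nlp, config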
@@ -216,12 +219,12 @@ def train(
     weights_data: Optional[bytes] = None,
     omit_extra_lookups: bool = False,
     disable_tqdm: bool = False,
-    randomization_index: int = 0,
-    num_workers=1
+    worker_id: int = 0,
+    num_workers=1,
 ) -> None:
     msg.info(f"Loading config from: {config_path}")
-    # Read the config first without creating objects, to get to the original nlp_config
     training = config["training"]
+    nlp_config = config["_nlp_config"]
     msg.info("Creating nlp from config")
     optimizer = training["optimizer"]
     if "textcat" in nlp_config["pipeline"]:
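train() now identifies itself as one of several peers: worker_id replaces randomization_index, and num_workers says how many workers exist in total. The diff does not show the launcher that spawns those workers; a minimal illustrative sketch of how the parameters might be driven with ray (the library verify_cli_args checks for below) could look like this, where run_worker and its body are hypothetical:

    import ray

    ray.init()

    @ray.remote
    def run_worker(worker_id: int, num_workers: int) -> None:
        # Each worker would call train(..., worker_id=worker_id,
        # num_workers=num_workers) against the same config and corpus paths.
        ...

    num_workers = 2
    ray.get([run_worker.remote(i, num_workers) for i in range(num_workers)])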
@@ -271,7 +274,7 @@ def train(
         tok2vec.from_bytes(weights_data)

     msg.info("Loading training corpus")
-    train_batches = create_train_batches(nlp, corpus, training, randomization_index)
+    train_batches = create_train_batches(nlp, corpus, training, worker_id)
     evaluate = create_evaluation_callback(nlp, optimizer, corpus, training)

     # Create iterator, which yields out info after each optimization step.
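The fourth argument to create_train_batches changes from a randomization index to the worker id, which suggests each worker now draws a disjoint slice of the batch stream. The function body is not in the diff; one plausible sharding scheme, as a sketch only:

    from itertools import islice

    def shard_batches(batches, worker_id: int, num_workers: int):
        # Worker i takes every num_workers-th batch starting at its own offset,
        # so the workers partition the stream without any coordination.
        return islice(batches, worker_id, None, num_workers)

With two workers, worker 0 would see batches 0, 2, 4, ... and worker 1 would see batches 1, 3, 5, ...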
@@ -289,7 +292,7 @@ def train(
         raw_text=raw_text,
     )
     msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
-    print_row = setup_printer(training, nlp)
+    print_row = setup_printer(training, nlp.pipe_names)
     tqdm_args = dict(total=training["eval_frequency"], leave=False, disable=disable_tqdm)
     try:
         progress = tqdm.tqdm(**tqdm_args)
@@ -476,6 +479,8 @@ def train_while_improving(
     ]
     raw_batches = util.minibatch(raw_examples, size=8)

+    start_time = timer()
+    words_seen = 0
     for step, (epoch, batch) in enumerate(train_data):
         dropout = next(dropouts)
         with nlp.select_pipes(enable=to_enable):
@@ -497,13 +502,16 @@ def train_while_improving(
         else:
             score, other_scores = (None, None)
             is_best_checkpoint = None
+        words_seen += sum(len(eg) for eg in batch)
         info = {
             "epoch": epoch,
             "step": step,
+            "words": words_seen,
             "score": score,
             "other_scores": other_scores,
             "losses": losses,
             "checkpoints": results,
+            "seconds": int(timer() - start_time)
         }
         yield batch, info, is_best_checkpoint
         if is_best_checkpoint is not None:
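Together with the start_time and words_seen additions above, the new "words" and "seconds" keys give every progress row enough information to report training throughput. A standalone illustration of the accounting, with toy token lists standing in for Example objects and timer aliased from timeit as spaCy's train script does:

    from timeit import default_timer as timer

    start_time = timer()
    words_seen = 0
    batches = [[["The", "cat"], ["sat"]], [["on", "the", "mat"]]]  # toy stand-ins
    for batch in batches:
        # len(eg) counts tokens per example, mirroring the diff.
        words_seen += sum(len(eg) for eg in batch)
    info = {"words": words_seen, "seconds": int(timer() - start_time)}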
@@ -532,14 +540,16 @@ def subdivide_batch(batch, accumulate_gradient):
         yield subbatch


-def setup_printer(training, nlp):
+def setup_printer(training, pipe_names):
     score_cols = training["scores"]
     score_widths = [max(len(col), 6) for col in score_cols]
-    loss_cols = [f"Loss {pipe}" for pipe in nlp.pipe_names]
+    loss_cols = [f"Loss {pipe}" for pipe in pipe_names]
     loss_widths = [max(len(col), 8) for col in loss_cols]
     table_header = ["E", "#"] + loss_cols + score_cols + ["Score"]
     table_header = [col.upper() for col in table_header]
     table_widths = [3, 6] + loss_widths + score_widths + [6]
+    table_header.append("WPS (TRAIN)")
+    table_widths.append(len(table_header[-1]))
     table_aligns = ["r" for _ in table_widths]

     msg.row(table_header, widths=table_widths)
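Passing pipe_names (a plain list of strings) instead of the whole nlp object keeps setup_printer free of heavyweight state, which presumably matters once training runs in separate worker processes: strings serialize trivially across process boundaries, a loaded pipeline does not. The header construction can be reproduced standalone; the pipeline and score names below are hypothetical:

    pipe_names = ["tagger", "parser"]          # hypothetical pipeline
    score_cols = ["tags_acc", "uas", "las"]    # hypothetical training["scores"]
    loss_cols = [f"Loss {pipe}" for pipe in pipe_names]
    table_header = [c.upper() for c in ["E", "#"] + loss_cols + score_cols + ["Score"]]
    table_header.append("WPS (TRAIN)")
    table_widths = [3, 6] + [max(len(c), 8) for c in loss_cols]
    table_widths += [max(len(c), 6) for c in score_cols] + [6]
    table_widths.append(len(table_header[-1]))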
@@ -549,7 +559,7 @@ def setup_printer(training, nlp):
         try:
             losses = [
                 "{0:.2f}".format(float(info["losses"][pipe_name]))
-                for pipe_name in nlp.pipe_names
+                for pipe_name in pipe_names
             ]
         except KeyError as e:
             raise KeyError(
@@ -575,6 +585,7 @@ def setup_printer(training, nlp):
             + losses
             + scores
             + ["{0:.2f}".format(float(info["score"]))]
+            + ["%d" % (info["words"] / info["seconds"])]
         )
         msg.row(data, widths=table_widths, aligns=table_aligns)

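One caveat: info["seconds"] is truncated to an integer, so any row printed within the first second of training makes this division raise ZeroDivisionError. A defensive variant of the new cell, as a sketch:

    def wps_cell(info: dict) -> str:
        # Guard the first tick, where int(timer() - start_time) is still 0;
        # "%d" truncates toward zero, matching the diff's formatting.
        return "%d" % (info["words"] / max(info["seconds"], 1))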
@@ -628,18 +639,15 @@ def verify_cli_args(
     if code_path is not None:
         if not code_path.exists():
             msg.fail("Path to Python code not found", code_path, exits=1)
-        try:
-            util.import_file("python_code", code_path)
-        except Exception as e:
-            msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
+        util.import_file("python_code", code_path)
     if init_tok2vec is not None and not init_tok2vec.exists():
         msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)

     if num_workers and num_workers > 1:
         try:
-            import spacy_ray
+            import ray
         except ImportError:
-            msg.fail("Need to `pip install spacy_ray` to use distributed training!", exits=1)
+            msg.fail("Need to `pip install ray` to use distributed training!", exits=1)


 def verify_textcat_config(nlp, nlp_config):
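The guard only demands ray when more than one worker is requested, so single-process training keeps working without the extra dependency. The same lazy-import pattern as a reusable sketch (require_ray is a hypothetical name):

    def require_ray(num_workers: int) -> None:
        if num_workers and num_workers > 1:
            try:
                import ray  # noqa: F401 -- only needed for distributed training
            except ImportError:
                raise SystemExit("Need to `pip install ray` to use distributed training!")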