From 4c5d6b13c842889834ca9047e5d96a4fb8075b3a Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Thu, 16 Jul 2020 19:16:43 +0200
Subject: [PATCH] Changes to train for parallel training. Temporary -- don't merge

---
 spacy/cli/train.py | 36 ++++++++++++++++++++++--------------
 1 file changed, 22 insertions(+), 14 deletions(-)

diff --git a/spacy/cli/train.py b/spacy/cli/train.py
index 374b27a52..f6e585edc 100644
--- a/spacy/cli/train.py
+++ b/spacy/cli/train.py
@@ -171,6 +171,7 @@ def train_cli(
             }
         )
     else:
+        msg.info(f"Loading config from: {config_path}")
         nlp, config = load_nlp_and_config(config_path)
 
     corpus = Corpus(train_path, dev_path, limit=config["training"]["limit"])
@@ -203,6 +204,8 @@ def load_nlp_and_config(config_path):
     nlp_config = config["nlp"]
     config = util.load_config(config_path, create_objects=True)
     nlp = util.load_model_from_config(nlp_config)
+    # TODO: This is hacky, but temporary convenience...
+    config["_nlp_config"] = nlp_config
     return nlp, config
 
 
@@ -216,12 +219,12 @@ def train(
     weights_data: Optional[bytes] = None,
     omit_extra_lookups: bool = False,
     disable_tqdm: bool = False,
-    randomization_index: int = 0,
-    num_workers=1
+    worker_id: int = 0,
+    num_workers=1,
 ) -> None:
-    msg.info(f"Loading config from: {config_path}")
     # Read the config first without creating objects, to get to the original nlp_config
     training = config["training"]
+    nlp_config = config["_nlp_config"]
     msg.info("Creating nlp from config")
     optimizer = training["optimizer"]
     if "textcat" in nlp_config["pipeline"]:
@@ -271,7 +274,7 @@ def train(
             tok2vec.from_bytes(weights_data)
 
     msg.info("Loading training corpus")
-    train_batches = create_train_batches(nlp, corpus, training, randomization_index)
+    train_batches = create_train_batches(nlp, corpus, training, worker_id)
     evaluate = create_evaluation_callback(nlp, optimizer, corpus, training)
 
     # Create iterator, which yields out info after each optimization step.
@@ -289,7 +292,7 @@ def train(
         raw_text=raw_text,
     )
     msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
-    print_row = setup_printer(training, nlp)
+    print_row = setup_printer(training, nlp.pipe_names)
     tqdm_args = dict(total=training["eval_frequency"], leave=False, disable=disable_tqdm)
     try:
         progress = tqdm.tqdm(**tqdm_args)
@@ -476,6 +479,8 @@ def train_while_improving(
         ]
         raw_batches = util.minibatch(raw_examples, size=8)
 
+    start_time = timer()
+    words_seen = 0
     for step, (epoch, batch) in enumerate(train_data):
         dropout = next(dropouts)
         with nlp.select_pipes(enable=to_enable):
@@ -497,13 +502,16 @@ def train_while_improving(
         else:
             score, other_scores = (None, None)
             is_best_checkpoint = None
+        words_seen += sum(len(eg) for eg in batch)
         info = {
             "epoch": epoch,
             "step": step,
+            "words": words_seen,
             "score": score,
             "other_scores": other_scores,
             "losses": losses,
             "checkpoints": results,
+            "seconds": int(timer() - start_time)
         }
         yield batch, info, is_best_checkpoint
         if is_best_checkpoint is not None:
@@ -532,14 +540,16 @@ def subdivide_batch(batch, accumulate_gradient):
             yield subbatch
 
 
-def setup_printer(training, nlp):
+def setup_printer(training, pipe_names):
     score_cols = training["scores"]
     score_widths = [max(len(col), 6) for col in score_cols]
-    loss_cols = [f"Loss {pipe}" for pipe in nlp.pipe_names]
+    loss_cols = [f"Loss {pipe}" for pipe in pipe_names]
     loss_widths = [max(len(col), 8) for col in loss_cols]
     table_header = ["E", "#"] + loss_cols + score_cols + ["Score"]
     table_header = [col.upper() for col in table_header]
     table_widths = [3, 6] + loss_widths + score_widths + [6]
+    table_header.append("WPS (TRAIN)")
+    table_widths.append(len(table_header[-1]))
     table_aligns = ["r" for _ in table_widths]
 
     msg.row(table_header, widths=table_widths)
@@ -549,7 +559,7 @@ def setup_printer(training, nlp):
         try:
             losses = [
                 "{0:.2f}".format(float(info["losses"][pipe_name]))
-                for pipe_name in nlp.pipe_names
+                for pipe_name in pipe_names
             ]
         except KeyError as e:
             raise KeyError(
@@ -575,6 +585,7 @@ def setup_printer(training, nlp):
             + losses
             + scores
             + ["{0:.2f}".format(float(info["score"]))]
+            + ["%d" % (info["words"] / info["seconds"])]
         )
 
         msg.row(data, widths=table_widths, aligns=table_aligns)
@@ -628,18 +639,15 @@ def verify_cli_args(
     if code_path is not None:
         if not code_path.exists():
             msg.fail("Path to Python code not found", code_path, exits=1)
-        try:
-            util.import_file("python_code", code_path)
-        except Exception as e:
-            msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
+        util.import_file("python_code", code_path)
 
     if init_tok2vec is not None and not init_tok2vec.exists():
         msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
     if num_workers and num_workers > 1:
         try:
-            import spacy_ray
+            import ray
         except ImportError:
-            msg.fail("Need to `pip install spacy_ray` to use distributed training!", exits=1)
+            msg.fail("Need to `pip install ray` to use distributed training!", exits=1)
 
 
 def verify_textcat_config(nlp, nlp_config):
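
Notes:

The patch threads worker_id through to create_train_batches, but the body of that function is not part of this diff, so the actual sharding scheme is not visible here. A minimal sketch of one plausible scheme, round-robin partitioning over a batch stream that all workers iterate in the same order (partition_batches is a hypothetical helper, not something the patch adds):

    from typing import Iterable, Iterator, TypeVar

    T = TypeVar("T")

    def partition_batches(batches: Iterable[T], worker_id: int, num_workers: int) -> Iterator[T]:
        # With a shared iteration order across workers, keeping batch i only
        # when i % num_workers == worker_id covers the stream without overlap.
        for i, batch in enumerate(batches):
            if i % num_workers == worker_id:
                yield batch

Read this way, the old parameter name randomization_index described a per-worker offset into a common shuffled order; the rename to worker_id makes the intent explicit now that num_workers sits next to it.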
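The new WPS (TRAIN) column divides the cumulative info["words"] by info["seconds"], and seconds is truncated with int(), so on a fast machine the first progress row can arrive with a zero denominator. Note also that the hunks call timer() without adding an import, so they presumably rely on an existing timeit.default_timer import (or equivalent) in train.py. A guarded variant of the same computation, as a sketch rather than what the patch does, assuming len(eg) gives the token count of an Example as the patch itself assumes:

    from timeit import default_timer as timer

    start_time = timer()
    words_seen = 0

    def words_per_second(batch):
        # Accumulate token counts and divide by wall-clock seconds, guarding
        # the near-zero case that int(timer() - start_time) can hit early on.
        global words_seen
        words_seen += sum(len(eg) for eg in batch)
        elapsed = timer() - start_time
        return words_seen / max(elapsed, 1e-6)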
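One pattern runs through the whole patch: functions that previously took rich objects now take plain data. setup_printer receives nlp.pipe_names instead of nlp, and train() reads nlp_config back out of config["_nlp_config"] instead of re-parsing the file. Plain lists and dicts cross process boundaries cheaply, which matters once num_workers > 1 and work is fanned out through Ray. The diff only adds the import check for ray; the driver code is elsewhere. A hypothetical sketch of such a fan-out, using only the public Ray primitives (run_worker and the "config.cfg" path are illustrative, not from the patch):

    import ray

    from spacy.cli.train import load_nlp_and_config  # exists on this branch

    ray.init()

    @ray.remote
    def run_worker(config_path, worker_id, num_workers):
        # A config path pickles cheaply to a remote process; a loaded
        # pipeline does not, so each worker builds its own nlp on arrival.
        nlp, config = load_nlp_and_config(config_path)
        # ... call train() here with worker_id/num_workers for this shard ...

    num_workers = 2
    futures = [run_worker.remote("config.cfg", i, num_workers) for i in range(num_workers)]
    ray.get(futures)

The division-of-labor between worker_id and num_workers in the patched train() signature matches this shape: every worker runs the same code and uses its id only to pick its slice of the batches.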