From 11fa0658f739b31effadfef5c2f277674fc1a7b8 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Sat, 20 Jun 2020 20:12:19 +0200
Subject: [PATCH] Work on train script

---
 spacy/cli/train.py | 151 ++++++++++++++++++++-------------------------
 1 file changed, 66 insertions(+), 85 deletions(-)

diff --git a/spacy/cli/train.py b/spacy/cli/train.py
index fb4347158..64eb89d13 100644
--- a/spacy/cli/train.py
+++ b/spacy/cli/train.py
@@ -12,7 +12,7 @@ import thinc.schedules
 from thinc.api import Model, use_pytorch_for_gpu_memory
 import random
 
-from ..gold import GoldCorpus
+from ..gold.corpus_docbin import Corpus
 from ..lookups import Lookups
 from .. import util
 from ..errors import Errors
@@ -148,26 +148,8 @@ def train_cli(
     command.
     """
     util.set_env_log(verbose)
+    verify_cli_args(**locals())
 
-    # Make sure all files and paths exists if they are needed
-    if not config_path or not config_path.exists():
-        msg.fail("Config file not found", config_path, exits=1)
-    if not train_path or not train_path.exists():
-        msg.fail("Training data not found", train_path, exits=1)
-    if not dev_path or not dev_path.exists():
-        msg.fail("Development data not found", dev_path, exits=1)
-    if output_path is not None:
-        if not output_path.exists():
-            output_path.mkdir()
-            msg.good(f"Created output directory: {output_path}")
-        elif output_path.exists() and [p for p in output_path.iterdir() if p.is_dir()]:
-            msg.warn(
-                "Output directory is not empty.",
-                "This can lead to unintended side effects when saving the model. "
-                "Please use an empty directory or a different path instead. If "
-                "the specified output path doesn't exist, the directory will be "
-                "created for you.",
-            )
     if raw_text is not None:
         raw_text = list(srsly.read_jsonl(raw_text))
     tag_map = {}
@@ -176,9 +158,7 @@
 
     weights_data = None
     if init_tok2vec is not None:
-        if not init_tok2vec.exists():
-            msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
-        with init_tok2vec.open("rb") as file_:
+        with init_tok2vec.open("rb") as file_:
             weights_data = file_.read()
 
     if use_gpu >= 0:
@@ -198,6 +178,7 @@
     )
 
 
+
 def train(
     config_path,
     data_paths,
@@ -221,60 +202,9 @@
     nlp = util.load_model_from_config(nlp_config)
     optimizer = training["optimizer"]
     limit = training["limit"]
-    msg.info("Loading training corpus")
-    corpus = GoldCorpus(data_paths["train"], data_paths["dev"], limit=limit)
-    # verify textcat config
+    corpus = Corpus(data_paths["train"], data_paths["dev"], limit=limit)
     if "textcat" in nlp_config["pipeline"]:
-        textcat_labels = set(nlp.get_pipe("textcat").labels)
-        textcat_multilabel = not nlp_config["pipeline"]["textcat"]["model"]["exclusive_classes"]
-
-        # check whether the setting 'exclusive_classes' corresponds to the provided training data
-        if textcat_multilabel:
-            multilabel_found = False
-            for eg in corpus.train_annotations:
-                cats = eg.reference.cats
-                textcat_labels.update(cats.keys())
-                if list(cats.values()).count(1.0) != 1:
-                    multilabel_found = True
-            if not multilabel_found:
-                msg.warn(
-                    "The textcat training instances look like they have "
-                    "mutually exclusive classes. Set 'exclusive_classes' "
-                    "to 'true' in the config to train a classifier with "
-                    "mutually exclusive classes more accurately."
-                )
-        else:
-            for eg in corpus.train_annotations:
-                cats = eg.reference.cats
-                textcat_labels.update(cats.keys())
-                if list(cats.values()).count(1.0) != 1:
-                    msg.fail(
-                        "Some textcat training instances do not have exactly "
-                        "one positive label. Set 'exclusive_classes' "
-                        "to 'false' in the config to train a classifier with classes "
-                        "that are not mutually exclusive."
-                    )
-        msg.info(f"Initialized textcat component for {len(textcat_labels)} unique labels")
-        nlp.get_pipe("textcat").labels = tuple(textcat_labels)
-
-        # if 'positive_label' is provided: double check whether it's in the data and the task is binary
-        if nlp_config["pipeline"]["textcat"].get("positive_label", None):
-            textcat_labels = nlp.get_pipe("textcat").cfg.get("labels", [])
-            pos_label = nlp_config["pipeline"]["textcat"]["positive_label"]
-            if pos_label not in textcat_labels:
-                msg.fail(
-                    f"The textcat's 'positive_label' config setting '{pos_label}' "
-                    f"does not match any label in the training data.",
-                    exits=1,
-                )
-            if len(textcat_labels) != 2:
-                msg.fail(
-                    f"A textcat 'positive_label' '{pos_label}' was "
-                    f"provided for training data that does not appear to be a "
-                    f"binary classification problem with two labels.",
-                    exits=1,
-                )
-
+        verify_textcat_config(nlp, nlp_config)
     if training.get("resume", False):
         msg.info("Resuming training")
         nlp.resume_training()
@@ -312,6 +242,7 @@
         )
         tok2vec.from_bytes(weights_data)
 
+    msg.info("Loading training corpus")
     train_batches = create_train_batches(nlp, corpus, training)
     evaluate = create_evaluation_callback(nlp, optimizer, corpus, training)
 
@@ -368,15 +299,7 @@ def create_train_batches(nlp, corpus, cfg):
     epochs_todo = cfg.get("max_epochs", 0)
     while True:
-        train_examples = list(
-            corpus.train_dataset(
-                nlp,
-                orth_variant_level=cfg["orth_variant_level"],
-                gold_preproc=cfg["gold_preproc"],
-                max_length=cfg["max_length"],
-                ignore_misaligned=True,
-            )
-        )
+        train_examples = list(corpus.train_dataset(nlp))
         if len(train_examples) == 0:
             raise ValueError(Errors.E988)
@@ -598,3 +521,61 @@ def update_meta(training, nlp, info):
         nlp.meta["performance"][metric] = info["other_scores"][metric]
     for pipe_name in nlp.pipe_names:
         nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]
+
+
+def verify_cli_args(
+    train_path,
+    dev_path,
+    config_path,
+    output_path=None,
+    init_tok2vec=None,
+    raw_text=None,
+    verbose=False,
+    use_gpu=-1,
+    tag_map_path=None,
+    omit_extra_lookups=False,
+):
+    # Make sure all files and paths exists if they are needed
+    if not config_path or not config_path.exists():
+        msg.fail("Config file not found", config_path, exits=1)
+    if not train_path or not train_path.exists():
+        msg.fail("Training data not found", train_path, exits=1)
+    if not dev_path or not dev_path.exists():
+        msg.fail("Development data not found", dev_path, exits=1)
+    if output_path is not None:
+        if not output_path.exists():
+            output_path.mkdir()
+            msg.good(f"Created output directory: {output_path}")
+        elif output_path.exists() and [p for p in output_path.iterdir() if p.is_dir()]:
+            msg.warn(
+                "Output directory is not empty.",
+                "This can lead to unintended side effects when saving the model. "
+                "Please use an empty directory or a different path instead. If "
If " + "the specified output path doesn't exist, the directory will be " + "created for you.", + ) + if init_tok2vec is not None and not init_tok2vec.exists(): + msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1) + + +def verify_textcat_config(nlp, nlp_config): + msg.info(f"Initialized textcat component for {len(textcat_labels)} unique labels") + nlp.get_pipe("textcat").labels = tuple(textcat_labels) + # if 'positive_label' is provided: double check whether it's in the data and + # the task is binary + if nlp_config["pipeline"]["textcat"].get("positive_label", None): + textcat_labels = nlp.get_pipe("textcat").cfg.get("labels", []) + pos_label = nlp_config["pipeline"]["textcat"]["positive_label"] + if pos_label not in textcat_labels: + msg.fail( + f"The textcat's 'positive_label' config setting '{pos_label}' " + f"does not match any label in the training data.", + exits=1, + ) + if len(textcat_labels) != 2: + msg.fail( + f"A textcat 'positive_label' '{pos_label}' was " + f"provided for training data that does not appear to be a " + f"binary classification problem with two labels.", + exits=1, + )