Work on train script

This commit is contained in:
Matthew Honnibal 2020-06-20 20:12:19 +02:00
parent 0de361cd00
commit 11fa0658f7

View File

@ -12,7 +12,7 @@ import thinc.schedules
from thinc.api import Model, use_pytorch_for_gpu_memory from thinc.api import Model, use_pytorch_for_gpu_memory
import random import random
from ..gold import GoldCorpus from ..gold.corpus_docbin import Corpus
from ..lookups import Lookups from ..lookups import Lookups
from .. import util from .. import util
from ..errors import Errors from ..errors import Errors
@ -148,26 +148,8 @@ def train_cli(
command. command.
""" """
util.set_env_log(verbose) util.set_env_log(verbose)
verify_cli_args(**locals())
# Make sure all files and paths exists if they are needed
if not config_path or not config_path.exists():
msg.fail("Config file not found", config_path, exits=1)
if not train_path or not train_path.exists():
msg.fail("Training data not found", train_path, exits=1)
if not dev_path or not dev_path.exists():
msg.fail("Development data not found", dev_path, exits=1)
if output_path is not None:
if not output_path.exists():
output_path.mkdir()
msg.good(f"Created output directory: {output_path}")
elif output_path.exists() and [p for p in output_path.iterdir() if p.is_dir()]:
msg.warn(
"Output directory is not empty.",
"This can lead to unintended side effects when saving the model. "
"Please use an empty directory or a different path instead. If "
"the specified output path doesn't exist, the directory will be "
"created for you.",
)
if raw_text is not None: if raw_text is not None:
raw_text = list(srsly.read_jsonl(raw_text)) raw_text = list(srsly.read_jsonl(raw_text))
tag_map = {} tag_map = {}
@ -176,8 +158,6 @@ def train_cli(
weights_data = None weights_data = None
if init_tok2vec is not None: if init_tok2vec is not None:
if not init_tok2vec.exists():
msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
with init_tok2vec.open("rb") as file_: with init_tok2vec.open("rb") as file_:
weights_data = file_.read() weights_data = file_.read()
@ -198,6 +178,7 @@ def train_cli(
) )
def train( def train(
config_path, config_path,
data_paths, data_paths,
@ -221,60 +202,9 @@ def train(
nlp = util.load_model_from_config(nlp_config) nlp = util.load_model_from_config(nlp_config)
optimizer = training["optimizer"] optimizer = training["optimizer"]
limit = training["limit"] limit = training["limit"]
msg.info("Loading training corpus") corpus = Corpus(data_paths["train"], data_paths["dev"], limit=limit)
corpus = GoldCorpus(data_paths["train"], data_paths["dev"], limit=limit)
# verify textcat config
if "textcat" in nlp_config["pipeline"]: if "textcat" in nlp_config["pipeline"]:
textcat_labels = set(nlp.get_pipe("textcat").labels) verify_textcat_config(nlp, nlp_config)
textcat_multilabel = not nlp_config["pipeline"]["textcat"]["model"]["exclusive_classes"]
# check whether the setting 'exclusive_classes' corresponds to the provided training data
if textcat_multilabel:
multilabel_found = False
for eg in corpus.train_annotations:
cats = eg.reference.cats
textcat_labels.update(cats.keys())
if list(cats.values()).count(1.0) != 1:
multilabel_found = True
if not multilabel_found:
msg.warn(
"The textcat training instances look like they have "
"mutually exclusive classes. Set 'exclusive_classes' "
"to 'true' in the config to train a classifier with "
"mutually exclusive classes more accurately."
)
else:
for eg in corpus.train_annotations:
cats = eg.reference.cats
textcat_labels.update(cats.keys())
if list(cats.values()).count(1.0) != 1:
msg.fail(
"Some textcat training instances do not have exactly "
"one positive label. Set 'exclusive_classes' "
"to 'false' in the config to train a classifier with classes "
"that are not mutually exclusive."
)
msg.info(f"Initialized textcat component for {len(textcat_labels)} unique labels")
nlp.get_pipe("textcat").labels = tuple(textcat_labels)
# if 'positive_label' is provided: double check whether it's in the data and the task is binary
if nlp_config["pipeline"]["textcat"].get("positive_label", None):
textcat_labels = nlp.get_pipe("textcat").cfg.get("labels", [])
pos_label = nlp_config["pipeline"]["textcat"]["positive_label"]
if pos_label not in textcat_labels:
msg.fail(
f"The textcat's 'positive_label' config setting '{pos_label}' "
f"does not match any label in the training data.",
exits=1,
)
if len(textcat_labels) != 2:
msg.fail(
f"A textcat 'positive_label' '{pos_label}' was "
f"provided for training data that does not appear to be a "
f"binary classification problem with two labels.",
exits=1,
)
if training.get("resume", False): if training.get("resume", False):
msg.info("Resuming training") msg.info("Resuming training")
nlp.resume_training() nlp.resume_training()
@ -312,6 +242,7 @@ def train(
) )
tok2vec.from_bytes(weights_data) tok2vec.from_bytes(weights_data)
msg.info("Loading training corpus")
train_batches = create_train_batches(nlp, corpus, training) train_batches = create_train_batches(nlp, corpus, training)
evaluate = create_evaluation_callback(nlp, optimizer, corpus, training) evaluate = create_evaluation_callback(nlp, optimizer, corpus, training)
@ -368,15 +299,7 @@ def train(
def create_train_batches(nlp, corpus, cfg): def create_train_batches(nlp, corpus, cfg):
epochs_todo = cfg.get("max_epochs", 0) epochs_todo = cfg.get("max_epochs", 0)
while True: while True:
train_examples = list( train_examples = list(corpus.train_dataset(nlp))
corpus.train_dataset(
nlp,
orth_variant_level=cfg["orth_variant_level"],
gold_preproc=cfg["gold_preproc"],
max_length=cfg["max_length"],
ignore_misaligned=True,
)
)
if len(train_examples) == 0: if len(train_examples) == 0:
raise ValueError(Errors.E988) raise ValueError(Errors.E988)
@ -598,3 +521,61 @@ def update_meta(training, nlp, info):
nlp.meta["performance"][metric] = info["other_scores"][metric] nlp.meta["performance"][metric] = info["other_scores"][metric]
for pipe_name in nlp.pipe_names: for pipe_name in nlp.pipe_names:
nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name] nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]
def verify_cli_args(
train_path,
dev_path,
config_path,
output_path=None,
init_tok2vec=None,
raw_text=None,
verbose=False,
use_gpu=-1,
tag_map_path=None,
omit_extra_lookups=False,
):
# Make sure all files and paths exists if they are needed
if not config_path or not config_path.exists():
msg.fail("Config file not found", config_path, exits=1)
if not train_path or not train_path.exists():
msg.fail("Training data not found", train_path, exits=1)
if not dev_path or not dev_path.exists():
msg.fail("Development data not found", dev_path, exits=1)
if output_path is not None:
if not output_path.exists():
output_path.mkdir()
msg.good(f"Created output directory: {output_path}")
elif output_path.exists() and [p for p in output_path.iterdir() if p.is_dir()]:
msg.warn(
"Output directory is not empty.",
"This can lead to unintended side effects when saving the model. "
"Please use an empty directory or a different path instead. If "
"the specified output path doesn't exist, the directory will be "
"created for you.",
)
if init_tok2vec is not None and not init_tok2vec.exists():
msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
def verify_textcat_config(nlp, nlp_config):
msg.info(f"Initialized textcat component for {len(textcat_labels)} unique labels")
nlp.get_pipe("textcat").labels = tuple(textcat_labels)
# if 'positive_label' is provided: double check whether it's in the data and
# the task is binary
if nlp_config["pipeline"]["textcat"].get("positive_label", None):
textcat_labels = nlp.get_pipe("textcat").cfg.get("labels", [])
pos_label = nlp_config["pipeline"]["textcat"]["positive_label"]
if pos_label not in textcat_labels:
msg.fail(
f"The textcat's 'positive_label' config setting '{pos_label}' "
f"does not match any label in the training data.",
exits=1,
)
if len(textcat_labels) != 2:
msg.fail(
f"A textcat 'positive_label' '{pos_label}' was "
f"provided for training data that does not appear to be a "
f"binary classification problem with two labels.",
exits=1,
)