mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 09:00:36 +03:00
Work on train script
This commit is contained in:
parent
0de361cd00
commit
11fa0658f7
|
@ -12,7 +12,7 @@ import thinc.schedules
|
|||
from thinc.api import Model, use_pytorch_for_gpu_memory
|
||||
import random
|
||||
|
||||
from ..gold import GoldCorpus
|
||||
from ..gold.corpus_docbin import Corpus
|
||||
from ..lookups import Lookups
|
||||
from .. import util
|
||||
from ..errors import Errors
|
||||
|
@ -148,26 +148,8 @@ def train_cli(
|
|||
command.
|
||||
"""
|
||||
util.set_env_log(verbose)
|
||||
verify_cli_args(**locals())
|
||||
|
||||
# Make sure all files and paths exists if they are needed
|
||||
if not config_path or not config_path.exists():
|
||||
msg.fail("Config file not found", config_path, exits=1)
|
||||
if not train_path or not train_path.exists():
|
||||
msg.fail("Training data not found", train_path, exits=1)
|
||||
if not dev_path or not dev_path.exists():
|
||||
msg.fail("Development data not found", dev_path, exits=1)
|
||||
if output_path is not None:
|
||||
if not output_path.exists():
|
||||
output_path.mkdir()
|
||||
msg.good(f"Created output directory: {output_path}")
|
||||
elif output_path.exists() and [p for p in output_path.iterdir() if p.is_dir()]:
|
||||
msg.warn(
|
||||
"Output directory is not empty.",
|
||||
"This can lead to unintended side effects when saving the model. "
|
||||
"Please use an empty directory or a different path instead. If "
|
||||
"the specified output path doesn't exist, the directory will be "
|
||||
"created for you.",
|
||||
)
|
||||
if raw_text is not None:
|
||||
raw_text = list(srsly.read_jsonl(raw_text))
|
||||
tag_map = {}
|
||||
|
@ -176,8 +158,6 @@ def train_cli(
|
|||
|
||||
weights_data = None
|
||||
if init_tok2vec is not None:
|
||||
if not init_tok2vec.exists():
|
||||
msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
|
||||
with init_tok2vec.open("rb") as file_:
|
||||
weights_data = file_.read()
|
||||
|
||||
|
@ -198,6 +178,7 @@ def train_cli(
|
|||
)
|
||||
|
||||
|
||||
|
||||
def train(
|
||||
config_path,
|
||||
data_paths,
|
||||
|
@ -221,60 +202,9 @@ def train(
|
|||
nlp = util.load_model_from_config(nlp_config)
|
||||
optimizer = training["optimizer"]
|
||||
limit = training["limit"]
|
||||
msg.info("Loading training corpus")
|
||||
corpus = GoldCorpus(data_paths["train"], data_paths["dev"], limit=limit)
|
||||
# verify textcat config
|
||||
corpus = Corpus(data_paths["train"], data_paths["dev"], limit=limit)
|
||||
if "textcat" in nlp_config["pipeline"]:
|
||||
textcat_labels = set(nlp.get_pipe("textcat").labels)
|
||||
textcat_multilabel = not nlp_config["pipeline"]["textcat"]["model"]["exclusive_classes"]
|
||||
|
||||
# check whether the setting 'exclusive_classes' corresponds to the provided training data
|
||||
if textcat_multilabel:
|
||||
multilabel_found = False
|
||||
for eg in corpus.train_annotations:
|
||||
cats = eg.reference.cats
|
||||
textcat_labels.update(cats.keys())
|
||||
if list(cats.values()).count(1.0) != 1:
|
||||
multilabel_found = True
|
||||
if not multilabel_found:
|
||||
msg.warn(
|
||||
"The textcat training instances look like they have "
|
||||
"mutually exclusive classes. Set 'exclusive_classes' "
|
||||
"to 'true' in the config to train a classifier with "
|
||||
"mutually exclusive classes more accurately."
|
||||
)
|
||||
else:
|
||||
for eg in corpus.train_annotations:
|
||||
cats = eg.reference.cats
|
||||
textcat_labels.update(cats.keys())
|
||||
if list(cats.values()).count(1.0) != 1:
|
||||
msg.fail(
|
||||
"Some textcat training instances do not have exactly "
|
||||
"one positive label. Set 'exclusive_classes' "
|
||||
"to 'false' in the config to train a classifier with classes "
|
||||
"that are not mutually exclusive."
|
||||
)
|
||||
msg.info(f"Initialized textcat component for {len(textcat_labels)} unique labels")
|
||||
nlp.get_pipe("textcat").labels = tuple(textcat_labels)
|
||||
|
||||
# if 'positive_label' is provided: double check whether it's in the data and the task is binary
|
||||
if nlp_config["pipeline"]["textcat"].get("positive_label", None):
|
||||
textcat_labels = nlp.get_pipe("textcat").cfg.get("labels", [])
|
||||
pos_label = nlp_config["pipeline"]["textcat"]["positive_label"]
|
||||
if pos_label not in textcat_labels:
|
||||
msg.fail(
|
||||
f"The textcat's 'positive_label' config setting '{pos_label}' "
|
||||
f"does not match any label in the training data.",
|
||||
exits=1,
|
||||
)
|
||||
if len(textcat_labels) != 2:
|
||||
msg.fail(
|
||||
f"A textcat 'positive_label' '{pos_label}' was "
|
||||
f"provided for training data that does not appear to be a "
|
||||
f"binary classification problem with two labels.",
|
||||
exits=1,
|
||||
)
|
||||
|
||||
verify_textcat_config(nlp, nlp_config)
|
||||
if training.get("resume", False):
|
||||
msg.info("Resuming training")
|
||||
nlp.resume_training()
|
||||
|
@ -312,6 +242,7 @@ def train(
|
|||
)
|
||||
tok2vec.from_bytes(weights_data)
|
||||
|
||||
msg.info("Loading training corpus")
|
||||
train_batches = create_train_batches(nlp, corpus, training)
|
||||
evaluate = create_evaluation_callback(nlp, optimizer, corpus, training)
|
||||
|
||||
|
@ -368,15 +299,7 @@ def train(
|
|||
def create_train_batches(nlp, corpus, cfg):
|
||||
epochs_todo = cfg.get("max_epochs", 0)
|
||||
while True:
|
||||
train_examples = list(
|
||||
corpus.train_dataset(
|
||||
nlp,
|
||||
orth_variant_level=cfg["orth_variant_level"],
|
||||
gold_preproc=cfg["gold_preproc"],
|
||||
max_length=cfg["max_length"],
|
||||
ignore_misaligned=True,
|
||||
)
|
||||
)
|
||||
train_examples = list(corpus.train_dataset(nlp))
|
||||
|
||||
if len(train_examples) == 0:
|
||||
raise ValueError(Errors.E988)
|
||||
|
@ -598,3 +521,61 @@ def update_meta(training, nlp, info):
|
|||
nlp.meta["performance"][metric] = info["other_scores"][metric]
|
||||
for pipe_name in nlp.pipe_names:
|
||||
nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]
|
||||
|
||||
|
||||
def verify_cli_args(
|
||||
train_path,
|
||||
dev_path,
|
||||
config_path,
|
||||
output_path=None,
|
||||
init_tok2vec=None,
|
||||
raw_text=None,
|
||||
verbose=False,
|
||||
use_gpu=-1,
|
||||
tag_map_path=None,
|
||||
omit_extra_lookups=False,
|
||||
):
|
||||
# Make sure all files and paths exists if they are needed
|
||||
if not config_path or not config_path.exists():
|
||||
msg.fail("Config file not found", config_path, exits=1)
|
||||
if not train_path or not train_path.exists():
|
||||
msg.fail("Training data not found", train_path, exits=1)
|
||||
if not dev_path or not dev_path.exists():
|
||||
msg.fail("Development data not found", dev_path, exits=1)
|
||||
if output_path is not None:
|
||||
if not output_path.exists():
|
||||
output_path.mkdir()
|
||||
msg.good(f"Created output directory: {output_path}")
|
||||
elif output_path.exists() and [p for p in output_path.iterdir() if p.is_dir()]:
|
||||
msg.warn(
|
||||
"Output directory is not empty.",
|
||||
"This can lead to unintended side effects when saving the model. "
|
||||
"Please use an empty directory or a different path instead. If "
|
||||
"the specified output path doesn't exist, the directory will be "
|
||||
"created for you.",
|
||||
)
|
||||
if init_tok2vec is not None and not init_tok2vec.exists():
|
||||
msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
|
||||
|
||||
|
||||
def verify_textcat_config(nlp, nlp_config):
|
||||
msg.info(f"Initialized textcat component for {len(textcat_labels)} unique labels")
|
||||
nlp.get_pipe("textcat").labels = tuple(textcat_labels)
|
||||
# if 'positive_label' is provided: double check whether it's in the data and
|
||||
# the task is binary
|
||||
if nlp_config["pipeline"]["textcat"].get("positive_label", None):
|
||||
textcat_labels = nlp.get_pipe("textcat").cfg.get("labels", [])
|
||||
pos_label = nlp_config["pipeline"]["textcat"]["positive_label"]
|
||||
if pos_label not in textcat_labels:
|
||||
msg.fail(
|
||||
f"The textcat's 'positive_label' config setting '{pos_label}' "
|
||||
f"does not match any label in the training data.",
|
||||
exits=1,
|
||||
)
|
||||
if len(textcat_labels) != 2:
|
||||
msg.fail(
|
||||
f"A textcat 'positive_label' '{pos_label}' was "
|
||||
f"provided for training data that does not appear to be a "
|
||||
f"binary classification problem with two labels.",
|
||||
exits=1,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user