Mirror of https://github.com/explosion/spaCy.git, synced 2025-02-11 17:10:36 +03:00
Work on train script
This commit is contained in:
parent 0de361cd00
commit 11fa0658f7
@@ -12,7 +12,7 @@ import thinc.schedules
 from thinc.api import Model, use_pytorch_for_gpu_memory
 import random

-from ..gold import GoldCorpus
+from ..gold.corpus_docbin import Corpus
 from ..lookups import Lookups
 from .. import util
 from ..errors import Errors
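The import swap above replaces the JSON-based GoldCorpus with a Corpus reader backed by spaCy's DocBin serialization. A minimal sketch of what loading DocBin-serialized training data looks like, assuming only the public DocBin API; the file path and blank pipeline are made-up examples, not what this script does verbatim:

from pathlib import Path

import spacy
from spacy.tokens import DocBin

nlp = spacy.blank("en")  # any vocab to deserialize the Docs against
train_path = Path("train.spacy")  # hypothetical DocBin file on disk

# DocBin.from_bytes returns the DocBin itself; get_docs re-creates Docs
doc_bin = DocBin().from_bytes(train_path.read_bytes())
docs = list(doc_bin.get_docs(nlp.vocab))
print(f"Loaded {len(docs)} docs")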
@@ -148,26 +148,8 @@ def train_cli(
     command.
     """
     util.set_env_log(verbose)
+    verify_cli_args(**locals())

-    # Make sure all files and paths exists if they are needed
-    if not config_path or not config_path.exists():
-        msg.fail("Config file not found", config_path, exits=1)
-    if not train_path or not train_path.exists():
-        msg.fail("Training data not found", train_path, exits=1)
-    if not dev_path or not dev_path.exists():
-        msg.fail("Development data not found", dev_path, exits=1)
-    if output_path is not None:
-        if not output_path.exists():
-            output_path.mkdir()
-            msg.good(f"Created output directory: {output_path}")
-        elif output_path.exists() and [p for p in output_path.iterdir() if p.is_dir()]:
-            msg.warn(
-                "Output directory is not empty.",
-                "This can lead to unintended side effects when saving the model. "
-                "Please use an empty directory or a different path instead. If "
-                "the specified output path doesn't exist, the directory will be "
-                "created for you.",
-            )
     if raw_text is not None:
         raw_text = list(srsly.read_jsonl(raw_text))
     tag_map = {}
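The removed checks are replaced by a single verify_cli_args(**locals()) call (the function itself is added at the bottom of this diff). Calling locals() on the first line of a function body captures exactly its arguments, which can then be forwarded wholesale to a validator with a matching signature. A small self-contained sketch of the idiom, with made-up names:

def validate(name, retries=3):
    # fail fast on bad arguments before any real work happens
    if retries < 0:
        raise ValueError(f"{name}: retries must be non-negative")

def run(name, retries=3):
    # locals() here contains only the arguments, since nothing else has
    # been assigned yet; the validator must accept every one of them
    validate(**locals())
    print(f"running {name} with {retries} retries")

run("demo")

The idiom only works while no other local variables have been bound yet, which is why the call sits at the very top of the function.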
@@ -176,8 +158,6 @@ def train_cli(

     weights_data = None
     if init_tok2vec is not None:
-        if not init_tok2vec.exists():
-            msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
         with init_tok2vec.open("rb") as file_:
             weights_data = file_.read()

@@ -198,6 +178,7 @@ def train_cli(
     )


+
 def train(
     config_path,
     data_paths,
@@ -221,60 +202,9 @@ def train(
     nlp = util.load_model_from_config(nlp_config)
     optimizer = training["optimizer"]
     limit = training["limit"]
-    msg.info("Loading training corpus")
-    corpus = GoldCorpus(data_paths["train"], data_paths["dev"], limit=limit)
-    # verify textcat config
+    corpus = Corpus(data_paths["train"], data_paths["dev"], limit=limit)
     if "textcat" in nlp_config["pipeline"]:
-        textcat_labels = set(nlp.get_pipe("textcat").labels)
-        textcat_multilabel = not nlp_config["pipeline"]["textcat"]["model"]["exclusive_classes"]
-
-        # check whether the setting 'exclusive_classes' corresponds to the provided training data
-        if textcat_multilabel:
-            multilabel_found = False
-            for eg in corpus.train_annotations:
-                cats = eg.reference.cats
-                textcat_labels.update(cats.keys())
-                if list(cats.values()).count(1.0) != 1:
-                    multilabel_found = True
-            if not multilabel_found:
-                msg.warn(
-                    "The textcat training instances look like they have "
-                    "mutually exclusive classes. Set 'exclusive_classes' "
-                    "to 'true' in the config to train a classifier with "
-                    "mutually exclusive classes more accurately."
-                )
-        else:
-            for eg in corpus.train_annotations:
-                cats = eg.reference.cats
-                textcat_labels.update(cats.keys())
-                if list(cats.values()).count(1.0) != 1:
-                    msg.fail(
-                        "Some textcat training instances do not have exactly "
-                        "one positive label. Set 'exclusive_classes' "
-                        "to 'false' in the config to train a classifier with classes "
-                        "that are not mutually exclusive."
-                    )
-        msg.info(f"Initialized textcat component for {len(textcat_labels)} unique labels")
-        nlp.get_pipe("textcat").labels = tuple(textcat_labels)
-
-        # if 'positive_label' is provided: double check whether it's in the data and the task is binary
-        if nlp_config["pipeline"]["textcat"].get("positive_label", None):
-            textcat_labels = nlp.get_pipe("textcat").cfg.get("labels", [])
-            pos_label = nlp_config["pipeline"]["textcat"]["positive_label"]
-            if pos_label not in textcat_labels:
-                msg.fail(
-                    f"The textcat's 'positive_label' config setting '{pos_label}' "
-                    f"does not match any label in the training data.",
-                    exits=1,
-                )
-            if len(textcat_labels) != 2:
-                msg.fail(
-                    f"A textcat 'positive_label' '{pos_label}' was "
-                    f"provided for training data that does not appear to be a "
-                    f"binary classification problem with two labels.",
-                    exits=1,
-                )
-
+        verify_textcat_config(nlp, nlp_config)
     if training.get("resume", False):
         msg.info("Resuming training")
         nlp.resume_training()
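The inline textcat checks collapse into verify_textcat_config (added at the end of this diff). The removed heuristic decides whether the data is multilabel by counting, per example, how many categories are scored exactly 1.0. A pure-Python sketch of that test over made-up cats dicts:

train_cats = [
    {"POSITIVE": 1.0, "NEGATIVE": 0.0},  # exactly one positive label
    {"POSITIVE": 1.0, "NEGATIVE": 1.0},  # two positive labels
]

def looks_multilabel(cats_per_example):
    # mirrors the removed list(cats.values()).count(1.0) != 1 test
    return any(list(cats.values()).count(1.0) != 1 for cats in cats_per_example)

print(looks_multilabel(train_cats))  # True, because of the second example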
@@ -312,6 +242,7 @@ def train(
         )
         tok2vec.from_bytes(weights_data)

+    msg.info("Loading training corpus")
     train_batches = create_train_batches(nlp, corpus, training)
     evaluate = create_evaluation_callback(nlp, optimizer, corpus, training)

@@ -368,15 +299,7 @@ def train(
 def create_train_batches(nlp, corpus, cfg):
     epochs_todo = cfg.get("max_epochs", 0)
     while True:
-        train_examples = list(
-            corpus.train_dataset(
-                nlp,
-                orth_variant_level=cfg["orth_variant_level"],
-                gold_preproc=cfg["gold_preproc"],
-                max_length=cfg["max_length"],
-                ignore_misaligned=True,
-            )
-        )
+        train_examples = list(corpus.train_dataset(nlp))

         if len(train_examples) == 0:
             raise ValueError(Errors.E988)
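create_train_batches keeps its epoch loop but no longer threads augmentation settings (orth_variant_level, gold_preproc, max_length) through to the corpus. A self-contained sketch of the surrounding generator pattern, simplified to plain lists; the batch sizing and shuffling details here are assumptions, not the script's exact behaviour:

import random

def train_batches(examples, batch_size=4, max_epochs=0):
    # an infinite stream of batches; max_epochs == 0 means loop forever
    epochs_todo = max_epochs
    while True:
        if not examples:
            raise ValueError("no training examples")
        random.shuffle(examples)
        for i in range(0, len(examples), batch_size):
            yield examples[i : i + batch_size]
        if max_epochs:
            epochs_todo -= 1
            if epochs_todo <= 0:
                break

batches = train_batches(list(range(10)), batch_size=4, max_epochs=2)
print(sum(1 for _ in batches))  # 6: three batches per epoch, two epochs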
@@ -598,3 +521,61 @@ def update_meta(training, nlp, info):
         nlp.meta["performance"][metric] = info["other_scores"][metric]
     for pipe_name in nlp.pipe_names:
         nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]
+
+
+def verify_cli_args(
+    train_path,
+    dev_path,
+    config_path,
+    output_path=None,
+    init_tok2vec=None,
+    raw_text=None,
+    verbose=False,
+    use_gpu=-1,
+    tag_map_path=None,
+    omit_extra_lookups=False,
+):
+    # Make sure all files and paths exists if they are needed
+    if not config_path or not config_path.exists():
+        msg.fail("Config file not found", config_path, exits=1)
+    if not train_path or not train_path.exists():
+        msg.fail("Training data not found", train_path, exits=1)
+    if not dev_path or not dev_path.exists():
+        msg.fail("Development data not found", dev_path, exits=1)
+    if output_path is not None:
+        if not output_path.exists():
+            output_path.mkdir()
+            msg.good(f"Created output directory: {output_path}")
+        elif output_path.exists() and [p for p in output_path.iterdir() if p.is_dir()]:
+            msg.warn(
+                "Output directory is not empty.",
+                "This can lead to unintended side effects when saving the model. "
+                "Please use an empty directory or a different path instead. If "
+                "the specified output path doesn't exist, the directory will be "
+                "created for you.",
+            )
+    if init_tok2vec is not None and not init_tok2vec.exists():
+        msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
+
+
+def verify_textcat_config(nlp, nlp_config):
+    msg.info(f"Initialized textcat component for {len(textcat_labels)} unique labels")
+    nlp.get_pipe("textcat").labels = tuple(textcat_labels)
+    # if 'positive_label' is provided: double check whether it's in the data and
+    # the task is binary
+    if nlp_config["pipeline"]["textcat"].get("positive_label", None):
+        textcat_labels = nlp.get_pipe("textcat").cfg.get("labels", [])
+        pos_label = nlp_config["pipeline"]["textcat"]["positive_label"]
+        if pos_label not in textcat_labels:
+            msg.fail(
+                f"The textcat's 'positive_label' config setting '{pos_label}' "
+                f"does not match any label in the training data.",
+                exits=1,
+            )
+        if len(textcat_labels) != 2:
+            msg.fail(
+                f"A textcat 'positive_label' '{pos_label}' was "
+                f"provided for training data that does not appear to be a "
+                f"binary classification problem with two labels.",
+                exits=1,
+            )
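The new verify_cli_args collects every filesystem check from train_cli in one place. Note that the "not empty" warning only fires when the output directory contains subdirectories, so stray files are tolerated. A standalone pathlib sketch of that check; the directory name is made up:

from pathlib import Path

def check_output_dir(output_path: Path) -> None:
    if not output_path.exists():
        output_path.mkdir()
        print(f"Created output directory: {output_path}")
    elif any(p.is_dir() for p in output_path.iterdir()):
        # only subdirectories (e.g. an earlier model-best/) trigger the warning
        print(f"Warning: output directory {output_path} is not empty")

check_output_dir(Path("output"))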