Improve train CLI with base model (#4911)

Improve train CLI with a provided base model so that you can:

* add a new component
* extend an existing component
* replace an existing component

When the final model and best model are saved, reenable any disabled
components and merge the meta information to include the full pipeline
and accuracy information for all components in the base model plus the
newly added components if needed.
This commit is contained in:
adrianeboyd 2020-01-16 01:58:51 +01:00 committed by Matthew Honnibal
parent 718704022a
commit 90c52128dc

View File

@ -30,6 +30,7 @@ from .. import about
raw_text=("Path to jsonl file with unlabelled text documents.", "option", "rt", Path),
base_model=("Name of model to update (optional)", "option", "b", str),
pipeline=("Comma-separated names of pipeline components", "option", "p", str),
replace_components=("Replace components from base model", "flag", "R", bool),
vectors=("Model to load vectors from", "option", "v", str),
n_iter=("Number of iterations", "option", "n", int),
n_early_stopping=("Maximum number of training epochs without dev accuracy improvement", "option", "ne", int),
@ -60,6 +61,7 @@ def train(
raw_text=None,
base_model=None,
pipeline="tagger,parser,ner",
replace_components=False,
vectors=None,
n_iter=30,
n_early_stopping=None,
@ -142,6 +144,8 @@ def train(
# the model and make sure the pipeline matches the pipeline setting. If
# training starts from a blank model, intitalize the language class.
pipeline = [p.strip() for p in pipeline.split(",")]
disabled_pipes = None
pipes_added = False
msg.text("Training pipeline: {}".format(pipeline))
if base_model:
msg.text("Starting with base model '{}'".format(base_model))
@ -152,20 +156,24 @@ def train(
"`lang` argument ('{}') ".format(nlp.lang, lang),
exits=1,
)
nlp.disable_pipes([p for p in nlp.pipe_names if p not in pipeline])
for pipe in pipeline:
pipe_cfg = {}
if pipe == "parser":
pipe_cfg = {"learn_tokens": learn_tokens}
elif pipe == "textcat":
pipe_cfg = {
"exclusive_classes": not textcat_multilabel,
"architecture": textcat_arch,
"positive_label": textcat_positive_label,
}
if pipe not in nlp.pipe_names:
if pipe == "parser":
pipe_cfg = {"learn_tokens": learn_tokens}
elif pipe == "textcat":
pipe_cfg = {
"exclusive_classes": not textcat_multilabel,
"architecture": textcat_arch,
"positive_label": textcat_positive_label,
}
else:
pipe_cfg = {}
msg.text("Adding component to base model '{}'".format(pipe))
nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg))
pipes_added = True
elif replace_components:
msg.text("Replacing component from base model '{}'".format(pipe))
nlp.replace_pipe(pipe, nlp.create_pipe(pipe, config=pipe_cfg))
pipes_added = True
else:
if pipe == "textcat":
textcat_cfg = nlp.get_pipe("textcat").cfg
@ -174,11 +182,6 @@ def train(
"architecture": textcat_cfg["architecture"],
"positive_label": textcat_cfg["positive_label"],
}
pipe_cfg = {
"exclusive_classes": not textcat_multilabel,
"architecture": textcat_arch,
"positive_label": textcat_positive_label,
}
if base_cfg != pipe_cfg:
msg.fail(
"The base textcat model configuration does"
@ -188,6 +191,8 @@ def train(
),
exits=1,
)
msg.text("Extending component from base model '{}'".format(pipe))
disabled_pipes = nlp.disable_pipes([p for p in nlp.pipe_names if p not in pipeline])
else:
msg.text("Starting with blank model '{}'".format(lang))
lang_cls = util.get_lang_class(lang)
@ -227,7 +232,7 @@ def train(
corpus = GoldCorpus(train_path, dev_path, limit=n_examples)
n_train_words = corpus.count_train()
if base_model:
if base_model and not pipes_added:
# Start with an existing model, use default optimizer
optimizer = create_default_optimizer(Model.ops)
else:
@ -243,7 +248,7 @@ def train(
# Verify textcat config
if "textcat" in pipeline:
textcat_labels = nlp.get_pipe("textcat").cfg["labels"]
textcat_labels = nlp.get_pipe("textcat").cfg.get("labels", [])
if textcat_positive_label and textcat_positive_label not in textcat_labels:
msg.fail(
"The textcat_positive_label (tpl) '{}' does not match any "
@ -426,11 +431,16 @@ def train(
"cpu": cpu_wps,
"gpu": gpu_wps,
}
meta["accuracy"] = scorer.scores
meta.setdefault("accuracy", {})
for component in nlp.pipe_names:
for metric in _get_metrics(component):
meta["accuracy"][metric] = scorer.scores[metric]
else:
meta.setdefault("beam_accuracy", {})
meta.setdefault("beam_speed", {})
meta["beam_accuracy"][beam_width] = scorer.scores
for component in nlp.pipe_names:
for metric in _get_metrics(component):
meta["beam_accuracy"][metric] = scorer.scores[metric]
meta["beam_speed"][beam_width] = {
"nwords": nwords,
"cpu": cpu_wps,
@ -486,12 +496,16 @@ def train(
)
break
finally:
best_pipes = nlp.pipe_names
if disabled_pipes:
disabled_pipes.restore()
with nlp.use_params(optimizer.averages):
final_model_path = output_path / "model-final"
nlp.to_disk(final_model_path)
final_meta = srsly.read_json(output_path / "model-final" / "meta.json")
msg.good("Saved model to output directory", final_model_path)
with msg.loading("Creating best model..."):
best_model_path = _collate_best_model(meta, output_path, nlp.pipe_names)
best_model_path = _collate_best_model(final_meta, output_path, best_pipes)
msg.good("Created best model", best_model_path)
@ -549,6 +563,7 @@ def _load_pretrained_tok2vec(nlp, loc):
def _collate_best_model(meta, output_path, components):
bests = {}
meta.setdefault("accuracy", {})
for component in components:
bests[component] = _find_best(output_path, component)
best_dest = output_path / "model-best"
@ -580,11 +595,13 @@ def _find_best(experiment_dir, component):
def _get_metrics(component):
if component == "parser":
return ("las", "uas", "token_acc")
return ("las", "uas", "las_per_type", "token_acc")
elif component == "tagger":
return ("tags_acc",)
elif component == "ner":
return ("ents_f", "ents_p", "ents_r")
return ("ents_f", "ents_p", "ents_r", "ents_per_type")
elif component == "textcat":
return ("textcat_score",)
return ("token_acc",)