Improve train CLI with base model (#4911)

Improve the train CLI when a base model is provided, so that you can:

* add a new component
* extend an existing component
* replace an existing component

When the final model and best model are saved, re-enable any disabled
components and merge the meta information so that it includes the full
pipeline and the accuracy information for all components in the base
model, plus any newly added components.
This commit is contained in:
adrianeboyd 2020-01-16 01:58:51 +01:00 committed by Matthew Honnibal
parent 718704022a
commit 90c52128dc

View File

@ -30,6 +30,7 @@ from .. import about
raw_text=("Path to jsonl file with unlabelled text documents.", "option", "rt", Path), raw_text=("Path to jsonl file with unlabelled text documents.", "option", "rt", Path),
base_model=("Name of model to update (optional)", "option", "b", str), base_model=("Name of model to update (optional)", "option", "b", str),
pipeline=("Comma-separated names of pipeline components", "option", "p", str), pipeline=("Comma-separated names of pipeline components", "option", "p", str),
replace_components=("Replace components from base model", "flag", "R", bool),
vectors=("Model to load vectors from", "option", "v", str), vectors=("Model to load vectors from", "option", "v", str),
n_iter=("Number of iterations", "option", "n", int), n_iter=("Number of iterations", "option", "n", int),
n_early_stopping=("Maximum number of training epochs without dev accuracy improvement", "option", "ne", int), n_early_stopping=("Maximum number of training epochs without dev accuracy improvement", "option", "ne", int),
@ -60,6 +61,7 @@ def train(
raw_text=None, raw_text=None,
base_model=None, base_model=None,
pipeline="tagger,parser,ner", pipeline="tagger,parser,ner",
replace_components=False,
vectors=None, vectors=None,
n_iter=30, n_iter=30,
n_early_stopping=None, n_early_stopping=None,
@ -142,6 +144,8 @@ def train(
# the model and make sure the pipeline matches the pipeline setting. If # the model and make sure the pipeline matches the pipeline setting. If
# training starts from a blank model, initialize the language class. # training starts from a blank model, initialize the language class.
pipeline = [p.strip() for p in pipeline.split(",")] pipeline = [p.strip() for p in pipeline.split(",")]
disabled_pipes = None
pipes_added = False
msg.text("Training pipeline: {}".format(pipeline)) msg.text("Training pipeline: {}".format(pipeline))
if base_model: if base_model:
msg.text("Starting with base model '{}'".format(base_model)) msg.text("Starting with base model '{}'".format(base_model))
@ -152,20 +156,24 @@ def train(
"`lang` argument ('{}') ".format(nlp.lang, lang), "`lang` argument ('{}') ".format(nlp.lang, lang),
exits=1, exits=1,
) )
nlp.disable_pipes([p for p in nlp.pipe_names if p not in pipeline])
for pipe in pipeline: for pipe in pipeline:
pipe_cfg = {}
if pipe == "parser":
pipe_cfg = {"learn_tokens": learn_tokens}
elif pipe == "textcat":
pipe_cfg = {
"exclusive_classes": not textcat_multilabel,
"architecture": textcat_arch,
"positive_label": textcat_positive_label,
}
if pipe not in nlp.pipe_names: if pipe not in nlp.pipe_names:
if pipe == "parser": msg.text("Adding component to base model '{}'".format(pipe))
pipe_cfg = {"learn_tokens": learn_tokens}
elif pipe == "textcat":
pipe_cfg = {
"exclusive_classes": not textcat_multilabel,
"architecture": textcat_arch,
"positive_label": textcat_positive_label,
}
else:
pipe_cfg = {}
nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg)) nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg))
pipes_added = True
elif replace_components:
msg.text("Replacing component from base model '{}'".format(pipe))
nlp.replace_pipe(pipe, nlp.create_pipe(pipe, config=pipe_cfg))
pipes_added = True
else: else:
if pipe == "textcat": if pipe == "textcat":
textcat_cfg = nlp.get_pipe("textcat").cfg textcat_cfg = nlp.get_pipe("textcat").cfg
@ -174,11 +182,6 @@ def train(
"architecture": textcat_cfg["architecture"], "architecture": textcat_cfg["architecture"],
"positive_label": textcat_cfg["positive_label"], "positive_label": textcat_cfg["positive_label"],
} }
pipe_cfg = {
"exclusive_classes": not textcat_multilabel,
"architecture": textcat_arch,
"positive_label": textcat_positive_label,
}
if base_cfg != pipe_cfg: if base_cfg != pipe_cfg:
msg.fail( msg.fail(
"The base textcat model configuration does" "The base textcat model configuration does"
@ -188,6 +191,8 @@ def train(
), ),
exits=1, exits=1,
) )
msg.text("Extending component from base model '{}'".format(pipe))
disabled_pipes = nlp.disable_pipes([p for p in nlp.pipe_names if p not in pipeline])
else: else:
msg.text("Starting with blank model '{}'".format(lang)) msg.text("Starting with blank model '{}'".format(lang))
lang_cls = util.get_lang_class(lang) lang_cls = util.get_lang_class(lang)
@ -227,7 +232,7 @@ def train(
corpus = GoldCorpus(train_path, dev_path, limit=n_examples) corpus = GoldCorpus(train_path, dev_path, limit=n_examples)
n_train_words = corpus.count_train() n_train_words = corpus.count_train()
if base_model: if base_model and not pipes_added:
# Start with an existing model, use default optimizer # Start with an existing model, use default optimizer
optimizer = create_default_optimizer(Model.ops) optimizer = create_default_optimizer(Model.ops)
else: else:
@ -243,7 +248,7 @@ def train(
# Verify textcat config # Verify textcat config
if "textcat" in pipeline: if "textcat" in pipeline:
textcat_labels = nlp.get_pipe("textcat").cfg["labels"] textcat_labels = nlp.get_pipe("textcat").cfg.get("labels", [])
if textcat_positive_label and textcat_positive_label not in textcat_labels: if textcat_positive_label and textcat_positive_label not in textcat_labels:
msg.fail( msg.fail(
"The textcat_positive_label (tpl) '{}' does not match any " "The textcat_positive_label (tpl) '{}' does not match any "
@ -426,11 +431,16 @@ def train(
"cpu": cpu_wps, "cpu": cpu_wps,
"gpu": gpu_wps, "gpu": gpu_wps,
} }
meta["accuracy"] = scorer.scores meta.setdefault("accuracy", {})
for component in nlp.pipe_names:
for metric in _get_metrics(component):
meta["accuracy"][metric] = scorer.scores[metric]
else: else:
meta.setdefault("beam_accuracy", {}) meta.setdefault("beam_accuracy", {})
meta.setdefault("beam_speed", {}) meta.setdefault("beam_speed", {})
meta["beam_accuracy"][beam_width] = scorer.scores for component in nlp.pipe_names:
for metric in _get_metrics(component):
meta["beam_accuracy"][metric] = scorer.scores[metric]
meta["beam_speed"][beam_width] = { meta["beam_speed"][beam_width] = {
"nwords": nwords, "nwords": nwords,
"cpu": cpu_wps, "cpu": cpu_wps,
@ -486,12 +496,16 @@ def train(
) )
break break
finally: finally:
best_pipes = nlp.pipe_names
if disabled_pipes:
disabled_pipes.restore()
with nlp.use_params(optimizer.averages): with nlp.use_params(optimizer.averages):
final_model_path = output_path / "model-final" final_model_path = output_path / "model-final"
nlp.to_disk(final_model_path) nlp.to_disk(final_model_path)
final_meta = srsly.read_json(output_path / "model-final" / "meta.json")
msg.good("Saved model to output directory", final_model_path) msg.good("Saved model to output directory", final_model_path)
with msg.loading("Creating best model..."): with msg.loading("Creating best model..."):
best_model_path = _collate_best_model(meta, output_path, nlp.pipe_names) best_model_path = _collate_best_model(final_meta, output_path, best_pipes)
msg.good("Created best model", best_model_path) msg.good("Created best model", best_model_path)
@ -549,6 +563,7 @@ def _load_pretrained_tok2vec(nlp, loc):
def _collate_best_model(meta, output_path, components): def _collate_best_model(meta, output_path, components):
bests = {} bests = {}
meta.setdefault("accuracy", {})
for component in components: for component in components:
bests[component] = _find_best(output_path, component) bests[component] = _find_best(output_path, component)
best_dest = output_path / "model-best" best_dest = output_path / "model-best"
@ -580,11 +595,13 @@ def _find_best(experiment_dir, component):
def _get_metrics(component): def _get_metrics(component):
if component == "parser": if component == "parser":
return ("las", "uas", "token_acc") return ("las", "uas", "las_per_type", "token_acc")
elif component == "tagger": elif component == "tagger":
return ("tags_acc",) return ("tags_acc",)
elif component == "ner": elif component == "ner":
return ("ents_f", "ents_p", "ents_r") return ("ents_f", "ents_p", "ents_r", "ents_per_type")
elif component == "textcat":
return ("textcat_score",)
return ("token_acc",) return ("token_acc",)