mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Improve train CLI with base model (#4911)
Improve the train CLI when a base model is provided, so that you can: (1) add a new component, (2) extend an existing component, or (3) replace an existing component. When the final model and the best model are saved, re-enable any disabled components and merge the meta information so that it includes the full pipeline and the accuracy information for all components in the base model, plus any newly added components.
This commit is contained in:
parent
718704022a
commit
90c52128dc
|
@ -30,6 +30,7 @@ from .. import about
|
|||
raw_text=("Path to jsonl file with unlabelled text documents.", "option", "rt", Path),
|
||||
base_model=("Name of model to update (optional)", "option", "b", str),
|
||||
pipeline=("Comma-separated names of pipeline components", "option", "p", str),
|
||||
replace_components=("Replace components from base model", "flag", "R", bool),
|
||||
vectors=("Model to load vectors from", "option", "v", str),
|
||||
n_iter=("Number of iterations", "option", "n", int),
|
||||
n_early_stopping=("Maximum number of training epochs without dev accuracy improvement", "option", "ne", int),
|
||||
|
@ -60,6 +61,7 @@ def train(
|
|||
raw_text=None,
|
||||
base_model=None,
|
||||
pipeline="tagger,parser,ner",
|
||||
replace_components=False,
|
||||
vectors=None,
|
||||
n_iter=30,
|
||||
n_early_stopping=None,
|
||||
|
@ -142,6 +144,8 @@ def train(
|
|||
# the model and make sure the pipeline matches the pipeline setting. If
|
||||
# training starts from a blank model, initialize the language class.
|
||||
pipeline = [p.strip() for p in pipeline.split(",")]
|
||||
disabled_pipes = None
|
||||
pipes_added = False
|
||||
msg.text("Training pipeline: {}".format(pipeline))
|
||||
if base_model:
|
||||
msg.text("Starting with base model '{}'".format(base_model))
|
||||
|
@ -152,20 +156,24 @@ def train(
|
|||
"`lang` argument ('{}') ".format(nlp.lang, lang),
|
||||
exits=1,
|
||||
)
|
||||
nlp.disable_pipes([p for p in nlp.pipe_names if p not in pipeline])
|
||||
for pipe in pipeline:
|
||||
pipe_cfg = {}
|
||||
if pipe == "parser":
|
||||
pipe_cfg = {"learn_tokens": learn_tokens}
|
||||
elif pipe == "textcat":
|
||||
pipe_cfg = {
|
||||
"exclusive_classes": not textcat_multilabel,
|
||||
"architecture": textcat_arch,
|
||||
"positive_label": textcat_positive_label,
|
||||
}
|
||||
if pipe not in nlp.pipe_names:
|
||||
if pipe == "parser":
|
||||
pipe_cfg = {"learn_tokens": learn_tokens}
|
||||
elif pipe == "textcat":
|
||||
pipe_cfg = {
|
||||
"exclusive_classes": not textcat_multilabel,
|
||||
"architecture": textcat_arch,
|
||||
"positive_label": textcat_positive_label,
|
||||
}
|
||||
else:
|
||||
pipe_cfg = {}
|
||||
msg.text("Adding component to base model '{}'".format(pipe))
|
||||
nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg))
|
||||
pipes_added = True
|
||||
elif replace_components:
|
||||
msg.text("Replacing component from base model '{}'".format(pipe))
|
||||
nlp.replace_pipe(pipe, nlp.create_pipe(pipe, config=pipe_cfg))
|
||||
pipes_added = True
|
||||
else:
|
||||
if pipe == "textcat":
|
||||
textcat_cfg = nlp.get_pipe("textcat").cfg
|
||||
|
@ -174,11 +182,6 @@ def train(
|
|||
"architecture": textcat_cfg["architecture"],
|
||||
"positive_label": textcat_cfg["positive_label"],
|
||||
}
|
||||
pipe_cfg = {
|
||||
"exclusive_classes": not textcat_multilabel,
|
||||
"architecture": textcat_arch,
|
||||
"positive_label": textcat_positive_label,
|
||||
}
|
||||
if base_cfg != pipe_cfg:
|
||||
msg.fail(
|
||||
"The base textcat model configuration does"
|
||||
|
@ -188,6 +191,8 @@ def train(
|
|||
),
|
||||
exits=1,
|
||||
)
|
||||
msg.text("Extending component from base model '{}'".format(pipe))
|
||||
disabled_pipes = nlp.disable_pipes([p for p in nlp.pipe_names if p not in pipeline])
|
||||
else:
|
||||
msg.text("Starting with blank model '{}'".format(lang))
|
||||
lang_cls = util.get_lang_class(lang)
|
||||
|
@ -227,7 +232,7 @@ def train(
|
|||
corpus = GoldCorpus(train_path, dev_path, limit=n_examples)
|
||||
n_train_words = corpus.count_train()
|
||||
|
||||
if base_model:
|
||||
if base_model and not pipes_added:
|
||||
# Start with an existing model, use default optimizer
|
||||
optimizer = create_default_optimizer(Model.ops)
|
||||
else:
|
||||
|
@ -243,7 +248,7 @@ def train(
|
|||
|
||||
# Verify textcat config
|
||||
if "textcat" in pipeline:
|
||||
textcat_labels = nlp.get_pipe("textcat").cfg["labels"]
|
||||
textcat_labels = nlp.get_pipe("textcat").cfg.get("labels", [])
|
||||
if textcat_positive_label and textcat_positive_label not in textcat_labels:
|
||||
msg.fail(
|
||||
"The textcat_positive_label (tpl) '{}' does not match any "
|
||||
|
@ -426,11 +431,16 @@ def train(
|
|||
"cpu": cpu_wps,
|
||||
"gpu": gpu_wps,
|
||||
}
|
||||
meta["accuracy"] = scorer.scores
|
||||
meta.setdefault("accuracy", {})
|
||||
for component in nlp.pipe_names:
|
||||
for metric in _get_metrics(component):
|
||||
meta["accuracy"][metric] = scorer.scores[metric]
|
||||
else:
|
||||
meta.setdefault("beam_accuracy", {})
|
||||
meta.setdefault("beam_speed", {})
|
||||
meta["beam_accuracy"][beam_width] = scorer.scores
|
||||
for component in nlp.pipe_names:
|
||||
for metric in _get_metrics(component):
|
||||
meta["beam_accuracy"][metric] = scorer.scores[metric]
|
||||
meta["beam_speed"][beam_width] = {
|
||||
"nwords": nwords,
|
||||
"cpu": cpu_wps,
|
||||
|
@ -486,12 +496,16 @@ def train(
|
|||
)
|
||||
break
|
||||
finally:
|
||||
best_pipes = nlp.pipe_names
|
||||
if disabled_pipes:
|
||||
disabled_pipes.restore()
|
||||
with nlp.use_params(optimizer.averages):
|
||||
final_model_path = output_path / "model-final"
|
||||
nlp.to_disk(final_model_path)
|
||||
final_meta = srsly.read_json(output_path / "model-final" / "meta.json")
|
||||
msg.good("Saved model to output directory", final_model_path)
|
||||
with msg.loading("Creating best model..."):
|
||||
best_model_path = _collate_best_model(meta, output_path, nlp.pipe_names)
|
||||
best_model_path = _collate_best_model(final_meta, output_path, best_pipes)
|
||||
msg.good("Created best model", best_model_path)
|
||||
|
||||
|
||||
|
@ -549,6 +563,7 @@ def _load_pretrained_tok2vec(nlp, loc):
|
|||
|
||||
def _collate_best_model(meta, output_path, components):
|
||||
bests = {}
|
||||
meta.setdefault("accuracy", {})
|
||||
for component in components:
|
||||
bests[component] = _find_best(output_path, component)
|
||||
best_dest = output_path / "model-best"
|
||||
|
@ -580,11 +595,13 @@ def _find_best(experiment_dir, component):
|
|||
|
||||
def _get_metrics(component):
|
||||
if component == "parser":
|
||||
return ("las", "uas", "token_acc")
|
||||
return ("las", "uas", "las_per_type", "token_acc")
|
||||
elif component == "tagger":
|
||||
return ("tags_acc",)
|
||||
elif component == "ner":
|
||||
return ("ents_f", "ents_p", "ents_r")
|
||||
return ("ents_f", "ents_p", "ents_r", "ents_per_type")
|
||||
elif component == "textcat":
|
||||
return ("textcat_score",)
|
||||
return ("token_acc",)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user