From 2c80b7c013a76cec87243c13f6c10e75352097a6 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 24 Jun 2018 23:39:52 +0200 Subject: [PATCH] Collate best model after training --- spacy/cli/train.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 5d6e2d55c..3426dfd65 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -185,6 +185,52 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0, with nlp.use_params(optimizer.averages): final_model_path = output_path / 'model-final' nlp.to_disk(final_model_path) + components = [] + if not no_parser: + components.append('parser') + if not no_tagger: + components.append('tagger') + if not no_entity: + components.append('ner') + _collate_best_model(meta, output_path, components) + +def _collate_best_model(meta, output_path, components): + bests = {} + for component in components: + bests[component] = _find_best(output_path, component) + best_dest = output_path / 'model-best' + shutil.copytree(output_path / 'model-final', best_dest) + for component, best_component_src in bests.items(): + shutil.rmtree(best_dir / component) + shutil.copytree(best_component_src, best_dest / component) + with (best_component_src / 'accuracy.json').open() as file_: + accs = json.load(file_) + for metric in _get_metrics(component): + meta['accuracy'][metric] = accs[metric] + with (best_dest / 'meta.json').open('w') as file_: + file_.write(json_dumps(meta)) + + +def _find_best(experiment_dir, component): + accuracies = [] + for epoch_model in experiment_dir.iterdir(): + if epoch_model.is_dir() and epoch_model.parts[-1] != "model-final": + accs = json.load((epoch_model / "accuracy.json").open()) + scores = [accs.get(metric, 0.0) for metric in _get_metrics(component)] + accuracies.append((scores, epoch_model)) + if accuracies: + return max(accuracies)[1] + else: + return None + +def _get_metrics(component): + if component == "parser": + return ("las", "uas", "token_acc") + elif component == "tagger": + return ("tags_acc",) + elif component == "ner": + return ("ents_f", "ents_p", "ents_r") + return ("token_acc",) def _render_parses(i, to_render):