mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
Collate best model after training
This commit is contained in:
parent
5435b071b9
commit
2c80b7c013
|
@ -185,6 +185,52 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
|
|||
with nlp.use_params(optimizer.averages):
|
||||
final_model_path = output_path / 'model-final'
|
||||
nlp.to_disk(final_model_path)
|
||||
components = []
|
||||
if not no_parser:
|
||||
components.append('parser')
|
||||
if not no_tagger:
|
||||
components.append('tagger')
|
||||
if not no_entity:
|
||||
components.append('ner')
|
||||
_collate_best_model(meta, output_path, components)
|
||||
|
||||
def _collate_best_model(meta, output_path, components):
|
||||
bests = {}
|
||||
for component in components:
|
||||
bests[component] = _find_best(output_path, component)
|
||||
best_dest = output_path / 'model-best'
|
||||
shutil.copytree(output_path / 'model-final', best_dest)
|
||||
for component, best_component_src in bests.items():
|
||||
shutil.rmtree(best_dir / component)
|
||||
shutil.copytree(best_component_src, best_dest / component)
|
||||
with (best_component_src / 'accuracy.json').open() as file_:
|
||||
accs = json.load(file_)
|
||||
for metric in _get_metrics(component):
|
||||
meta['accuracy'][metric] = accs[metric]
|
||||
with (best_dest / 'meta.json').open('w') as file_:
|
||||
file_.write(json_dumps(meta))
|
||||
|
||||
|
||||
def _find_best(experiment_dir, component):
|
||||
accuracies = []
|
||||
for epoch_model in experiment_dir.iterdir():
|
||||
if epoch_model.is_dir() and epoch_model.parts[-1] != "model-final":
|
||||
accs = json.load((epoch_model / "accuracy.json").open())
|
||||
scores = [accs.get(metric, 0.0) for metric in _get_metrics(component)]
|
||||
accuracies.append((scores, epoch_model))
|
||||
if accuracies:
|
||||
return max(accuracies)[1]
|
||||
else:
|
||||
return None
|
||||
|
||||
def _get_metrics(component):
|
||||
if component == "parser":
|
||||
return ("las", "uas", "token_acc")
|
||||
elif component == "tagger":
|
||||
return ("tags_acc",)
|
||||
elif component == "ner":
|
||||
return ("ents_f", "ents_p", "ents_r")
|
||||
return ("token_acc",)
|
||||
|
||||
|
||||
def _render_parses(i, to_render):
|
||||
|
|
Loading…
Reference in New Issue
Block a user