mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	Collate best model after training
This commit is contained in:
		
							parent
							
								
									5435b071b9
								
							
						
					
					
						commit
						2c80b7c013
					
				|  | @ -185,6 +185,52 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0, | |||
|         with nlp.use_params(optimizer.averages): | ||||
|             final_model_path = output_path / 'model-final' | ||||
|             nlp.to_disk(final_model_path) | ||||
|         components = [] | ||||
|         if not no_parser: | ||||
|             components.append('parser') | ||||
|         if not no_tagger: | ||||
|             components.append('tagger') | ||||
|         if not no_entity: | ||||
|             components.append('ner') | ||||
|         _collate_best_model(meta, output_path, components) | ||||
| 
 | ||||
| def _collate_best_model(meta, output_path, components): | ||||
|     bests = {} | ||||
|     for component in components: | ||||
|         bests[component] = _find_best(output_path, component) | ||||
|     best_dest = output_path / 'model-best' | ||||
|     shutil.copytree(output_path / 'model-final', best_dest) | ||||
|     for component, best_component_src in bests.items(): | ||||
|         shutil.rmtree(best_dir / component) | ||||
|         shutil.copytree(best_component_src, best_dest / component) | ||||
|         with (best_component_src / 'accuracy.json').open() as file_: | ||||
|             accs = json.load(file_) | ||||
|         for metric in _get_metrics(component): | ||||
|             meta['accuracy'][metric] = accs[metric] | ||||
|     with (best_dest / 'meta.json').open('w') as file_: | ||||
|         file_.write(json_dumps(meta)) | ||||
| 
 | ||||
| 
 | ||||
| def _find_best(experiment_dir, component): | ||||
|     accuracies = [] | ||||
|     for epoch_model in experiment_dir.iterdir(): | ||||
|         if epoch_model.is_dir() and epoch_model.parts[-1] != "model-final": | ||||
|             accs = json.load((epoch_model / "accuracy.json").open()) | ||||
|             scores = [accs.get(metric, 0.0) for metric in _get_metrics(component)] | ||||
|             accuracies.append((scores, epoch_model)) | ||||
|     if accuracies: | ||||
|         return max(accuracies)[1] | ||||
|     else: | ||||
|         return None | ||||
| 
 | ||||
| def _get_metrics(component): | ||||
|     if component == "parser": | ||||
|         return ("las", "uas", "token_acc") | ||||
|     elif component == "tagger": | ||||
|         return ("tags_acc",) | ||||
|     elif component == "ner": | ||||
|         return ("ents_f", "ents_p", "ents_r") | ||||
|     return ("token_acc",) | ||||
| 
 | ||||
| 
 | ||||
| def _render_parses(i, to_render): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user