mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	Add eval_beam_widths argument to spacy train
This commit is contained in:
		
							parent
							
								
									b13b2aeb54
								
							
						
					
					
						commit
						daa8c3787a
					
				|  | @ -58,6 +58,7 @@ from .. import about | |||
|         str, | ||||
|     ), | ||||
|     noise_level=("Amount of corruption for data augmentation", "option", "nl", float), | ||||
|     eval_beam_widths=("Beam widths to evaluate, e.g. 4,8", "option", "bw", str), | ||||
|     gold_preproc=("Use gold preprocessing", "flag", "G", bool), | ||||
|     learn_tokens=("Make parser learn gold-standard tokenization", "flag", "T", bool), | ||||
|     verbose=("Display more information for debug", "flag", "VV", bool), | ||||
|  | @ -81,6 +82,7 @@ def train( | |||
|     parser_multitasks="", | ||||
|     entity_multitasks="", | ||||
|     noise_level=0.0, | ||||
|     eval_beam_widths="", | ||||
|     gold_preproc=False, | ||||
|     learn_tokens=False, | ||||
|     verbose=False, | ||||
|  | @ -134,6 +136,14 @@ def train( | |||
|         util.env_opt("batch_compound", 1.001), | ||||
|     ) | ||||
| 
 | ||||
|     if not eval_beam_widths: | ||||
|         eval_beam_widths = [1] | ||||
|     else: | ||||
|         eval_beam_widths = [int(bw) for bw in eval_beam_widths.split(",")] | ||||
|         if 1 not in eval_beam_widths: | ||||
|             eval_beam_widths.append(1) | ||||
|         eval_beam_widths.sort() | ||||
| 
 | ||||
|     # Set up the base model and pipeline. If a base model is specified, load | ||||
|     # the model and make sure the pipeline matches the pipeline setting. If | ||||
|     # training starts from a blank model, intitalize the language class. | ||||
|  | @ -247,7 +257,7 @@ def train( | |||
|                 epoch_model_path = output_path / ("model%d" % i) | ||||
|                 nlp.to_disk(epoch_model_path) | ||||
|                 nlp_loaded = util.load_model_from_path(epoch_model_path) | ||||
|                 for beam_width in [1, 4, 16, 128]: | ||||
|                 for beam_width in eval_beam_widths: | ||||
|                     for name, component in nlp_loaded.pipeline: | ||||
|                         if hasattr(component, "cfg"): | ||||
|                             component.cfg["beam_width"] = beam_width | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user