mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	Only add beam width if customised
This commit is contained in:
		
							parent
							
								
									7a354761c7
								
							
						
					
					
						commit
						c94742ff64
					
				|  | @ -143,6 +143,7 @@ def train( | ||||||
|         if 1 not in eval_beam_widths: |         if 1 not in eval_beam_widths: | ||||||
|             eval_beam_widths.append(1) |             eval_beam_widths.append(1) | ||||||
|         eval_beam_widths.sort() |         eval_beam_widths.sort() | ||||||
|  |     has_beam_widths = eval_beam_widths != [1] | ||||||
| 
 | 
 | ||||||
|     # Set up the base model and pipeline. If a base model is specified, load |     # Set up the base model and pipeline. If a base model is specified, load | ||||||
|     # the model and make sure the pipeline matches the pipeline setting. If |     # the model and make sure the pipeline matches the pipeline setting. If | ||||||
|  | @ -211,11 +212,11 @@ def train( | ||||||
| 
 | 
 | ||||||
|     # fmt: off |     # fmt: off | ||||||
|     row_head = ("Itn", "Beam Width", "Dep Loss", "NER Loss", "UAS", "NER P", "NER R", "NER F", "Tag %", "Token %", "CPU WPS", "GPU WPS") |     row_head = ("Itn", "Beam Width", "Dep Loss", "NER Loss", "UAS", "NER P", "NER R", "NER F", "Tag %", "Token %", "CPU WPS", "GPU WPS") | ||||||
|     row_settings = { |     row_widths = (3, 10, 10, 7, 7, 7, 7, 7, 7, 7, 7) | ||||||
|         "widths": (3, 10, 10, 10, 7, 7, 7, 7, 7, 7, 7, 7), |     if has_beam_widths: | ||||||
|         "aligns": tuple(["r" for i in row_head]), |         row_head.insert(1, "Beam W.") | ||||||
|         "spacing": 2 |         row_widths.insert(1, 7) | ||||||
|     } |     row_settings = {"widths": row_widths, "aligns": tuple(["r" for i in row_head]), "spacing": 2} | ||||||
|     # fmt: on |     # fmt: on | ||||||
|     print("") |     print("") | ||||||
|     msg.row(row_head, **row_settings) |     msg.row(row_head, **row_settings) | ||||||
|  | @ -318,8 +319,11 @@ def train( | ||||||
|                     srsly.write_json(meta_loc, meta) |                     srsly.write_json(meta_loc, meta) | ||||||
|                     util.set_env_log(verbose) |                     util.set_env_log(verbose) | ||||||
| 
 | 
 | ||||||
|  |                     progress_args = [i, losses, scorer.scores] | ||||||
|  |                     if has_beam_widths: | ||||||
|  |                         progress_args.inset(1, beam_with) | ||||||
|                     progress = _get_progress( |                     progress = _get_progress( | ||||||
|                         i, beam_width, losses, scorer.scores, cpu_wps=cpu_wps, gpu_wps=gpu_wps |                         *progress_args, cpu_wps=cpu_wps, gpu_wps=gpu_wps | ||||||
|                     ) |                     ) | ||||||
|                     msg.row(progress, **row_settings) |                     msg.row(progress, **row_settings) | ||||||
|     finally: |     finally: | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user