Mirror of https://github.com/explosion/spaCy.git (synced 2025-10-31 07:57:35 +03:00)
			
		
		
		
	Add gold_preproc flag to cli/train
This commit is contained in:
		
Parent: 42fa84075f
Commit: 84bb543e4d
					
				|  | @ -32,10 +32,12 @@ from ..compat import json_dumps | |||
|     resume=("Whether to resume training", "flag", "R", bool), | ||||
|     no_tagger=("Don't train tagger", "flag", "T", bool), | ||||
|     no_parser=("Don't train parser", "flag", "P", bool), | ||||
|     no_entities=("Don't train NER", "flag", "N", bool) | ||||
|     no_entities=("Don't train NER", "flag", "N", bool), | ||||
|     gold_preproc=("Use gold preprocessing", "flag", "G", bool), | ||||
| ) | ||||
| def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0, | ||||
|           use_gpu=-1, resume=False, no_tagger=False, no_parser=False, no_entities=False): | ||||
|           use_gpu=-1, resume=False, no_tagger=False, no_parser=False, no_entities=False, | ||||
|           gold_preproc=False): | ||||
|     """ | ||||
|     Train a model. Expects data in spaCy's JSON format. | ||||
|     """ | ||||
|  | @ -86,7 +88,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0, | |||
|                 i += 20 | ||||
|             with tqdm.tqdm(total=n_train_words, leave=False) as pbar: | ||||
|                 train_docs = corpus.train_docs(nlp, projectivize=True, | ||||
|                                                gold_preproc=False, max_length=0) | ||||
|                                                gold_preproc=gold_preproc, max_length=0) | ||||
|                 losses = {} | ||||
|                 for batch in minibatch(train_docs, size=batch_sizes): | ||||
|                     docs, golds = zip(*batch) | ||||
|  | @ -104,7 +106,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0, | |||
|                 scorer = nlp_loaded.evaluate( | ||||
|                             corpus.dev_docs( | ||||
|                                 nlp_loaded, | ||||
|                                 gold_preproc=False)) | ||||
|                                 gold_preproc=gold_preproc)) | ||||
|                 acc_loc =(output_path / ('model%d' % i) / 'accuracy.json') | ||||
|                 with acc_loc.open('w') as file_: | ||||
|                     file_.write(json_dumps(scorer.scores)) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user