Mirror of https://github.com/explosion/spaCy.git
	Add gold_preproc flag to cli/train
parent 42fa84075f
commit 84bb543e4d
--- a/spacy/cli/train.py
+++ b/spacy/cli/train.py
@@ -32,10 +32,12 @@ from ..compat import json_dumps
     resume=("Whether to resume training", "flag", "R", bool),
     no_tagger=("Don't train tagger", "flag", "T", bool),
     no_parser=("Don't train parser", "flag", "P", bool),
-    no_entities=("Don't train NER", "flag", "N", bool)
+    no_entities=("Don't train NER", "flag", "N", bool),
+    gold_preproc=("Use gold preprocessing", "flag", "G", bool),
 )
 def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
-          use_gpu=-1, resume=False, no_tagger=False, no_parser=False, no_entities=False):
+          use_gpu=-1, resume=False, no_tagger=False, no_parser=False, no_entities=False,
+          gold_preproc=False):
     """
     Train a model. Expects data in spaCy's JSON format.
     """
@@ -86,7 +88,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
                 i += 20
             with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
                 train_docs = corpus.train_docs(nlp, projectivize=True,
-                                               gold_preproc=False, max_length=0)
+                                               gold_preproc=gold_preproc, max_length=0)
                 losses = {}
                 for batch in minibatch(train_docs, size=batch_sizes):
                     docs, golds = zip(*batch)
@@ -104,7 +106,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
                 scorer = nlp_loaded.evaluate(
                             corpus.dev_docs(
                                 nlp_loaded,
-                                gold_preproc=False))
+                                gold_preproc=gold_preproc))
                 acc_loc =(output_path / ('model%d' % i) / 'accuracy.json')
                 with acc_loc.open('w') as file_:
                     file_.write(json_dumps(scorer.scores))
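Usage sketch (not part of the commit): plac turns a ("...", "flag", "G", bool) annotation into a boolean command-line flag, so after this change the option should be reachable as -G or, assuming plac's usual underscore-to-hyphen naming, --gold-preproc, e.g. `python -m spacy train en /output train.json dev.json --gold-preproc`. A minimal, self-contained toy command illustrating that plac convention (hypothetical demo script, not spaCy code):

    import plac

    @plac.annotations(
        # Same annotation shape as in the diff above: (help, kind, abbrev, type).
        gold_preproc=("Use gold preprocessing", "flag", "G", bool),
    )
    def main(gold_preproc=False):
        # plac exposes the "G" abbreviation as -G and (by convention) the
        # keyword name as --gold-preproc; passing either sets this to True.
        print("gold_preproc =", gold_preproc)

    if __name__ == '__main__':
        plac.call(main)  # e.g. `python demo.py -G` prints: gold_preproc = True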