mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	Remove unused model_dir option
As noted in #845, the `model_dir` argument was not being used. I've removed it for now, although it would be good to have this option restored and working.
This commit is contained in:
		
							parent
							
								
									724e51ed47
								
							
						
					
					
						commit
						c031c677cc
					
				|  | @ -18,7 +18,7 @@ except ImportError: | ||||||
|     import pickle |     import pickle | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def train(model_dir, train_loc, dev_loc, shape, settings): | def train(train_loc, dev_loc, shape, settings): | ||||||
|     train_texts1, train_texts2, train_labels = read_snli(train_loc) |     train_texts1, train_texts2, train_labels = read_snli(train_loc) | ||||||
|     dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc) |     dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc) | ||||||
| 
 | 
 | ||||||
|  | @ -44,7 +44,7 @@ def train(model_dir, train_loc, dev_loc, shape, settings): | ||||||
|         batch_size=settings['batch_size']) |         batch_size=settings['batch_size']) | ||||||
|     if not (nlp.path / 'similarity').exists(): |     if not (nlp.path / 'similarity').exists(): | ||||||
|         (nlp.path / 'similarity').mkdir() |         (nlp.path / 'similarity').mkdir() | ||||||
|     print("Saving to", model_dir / 'similarity') |     print("Saving to", nlp.path / 'similarity') | ||||||
|     weights = model.get_weights() |     weights = model.get_weights() | ||||||
|     with (nlp.path / 'similarity' / 'model').open('wb') as file_: |     with (nlp.path / 'similarity' / 'model').open('wb') as file_: | ||||||
|         pickle.dump(weights[1:], file_) |         pickle.dump(weights[1:], file_) | ||||||
|  | @ -68,8 +68,8 @@ def evaluate(model_dir, dev_loc): | ||||||
|     return correct, total |     return correct, total | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def demo(model_dir): | def demo(): | ||||||
|     nlp = spacy.load('en', path=model_dir, |     nlp = spacy.load('en', | ||||||
|             create_pipeline=create_similarity_pipeline) |             create_pipeline=create_similarity_pipeline) | ||||||
|     doc1 = nlp(u'What were the best crime fiction books in 2016?') |     doc1 = nlp(u'What were the best crime fiction books in 2016?') | ||||||
|     doc2 = nlp( |     doc2 = nlp( | ||||||
|  | @ -98,7 +98,6 @@ def read_snli(path): | ||||||
| 
 | 
 | ||||||
| @plac.annotations( | @plac.annotations( | ||||||
|     mode=("Mode to execute", "positional", None, str, ["train", "evaluate", "demo"]), |     mode=("Mode to execute", "positional", None, str, ["train", "evaluate", "demo"]), | ||||||
|     model_dir=("Path to spaCy model directory", "positional", None, Path), |  | ||||||
|     train_loc=("Path to training data", "positional", None, Path), |     train_loc=("Path to training data", "positional", None, Path), | ||||||
|     dev_loc=("Path to development data", "positional", None, Path), |     dev_loc=("Path to development data", "positional", None, Path), | ||||||
|     max_length=("Length to truncate sentences", "option", "L", int), |     max_length=("Length to truncate sentences", "option", "L", int), | ||||||
|  | @ -110,7 +109,7 @@ def read_snli(path): | ||||||
|     tree_truncate=("Truncate sentences by tree distance", "flag", "T", bool), |     tree_truncate=("Truncate sentences by tree distance", "flag", "T", bool), | ||||||
|     gru_encode=("Encode sentences with bidirectional GRU", "flag", "E", bool), |     gru_encode=("Encode sentences with bidirectional GRU", "flag", "E", bool), | ||||||
| ) | ) | ||||||
| def main(mode, model_dir, train_loc, dev_loc, | def main(mode, train_loc, dev_loc, | ||||||
|         tree_truncate=False, |         tree_truncate=False, | ||||||
|         gru_encode=False, |         gru_encode=False, | ||||||
|         max_length=100, |         max_length=100, | ||||||
|  | @ -129,12 +128,12 @@ def main(mode, model_dir, train_loc, dev_loc, | ||||||
|         'gru_encode': gru_encode |         'gru_encode': gru_encode | ||||||
|     } |     } | ||||||
|     if mode == 'train': |     if mode == 'train': | ||||||
|         train(model_dir, train_loc, dev_loc, shape, settings) |         train(train_loc, dev_loc, shape, settings) | ||||||
|     elif mode == 'evaluate': |     elif mode == 'evaluate': | ||||||
|         correct, total = evaluate(model_dir, dev_loc) |         correct, total = evaluate(dev_loc) | ||||||
|         print(correct, '/', total, correct / total) |         print(correct, '/', total, correct / total) | ||||||
|     else: |     else: | ||||||
|         demo(model_dir) |         demo() | ||||||
| 
 | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     plac.call(main) |     plac.call(main) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user