Remove unused model_dir option

As noted in #845, the `model_dir` argument was not being used. I've removed it for now, although it would be good to have this option restored and working.
This commit is contained in:
Matthew Honnibal 2017-02-18 10:38:22 +01:00 committed by GitHub
parent 724e51ed47
commit c031c677cc

View File

@ -18,7 +18,7 @@ except ImportError:
import pickle import pickle
def train(model_dir, train_loc, dev_loc, shape, settings): def train(train_loc, dev_loc, shape, settings):
train_texts1, train_texts2, train_labels = read_snli(train_loc) train_texts1, train_texts2, train_labels = read_snli(train_loc)
dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc) dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc)
@ -44,7 +44,7 @@ def train(model_dir, train_loc, dev_loc, shape, settings):
batch_size=settings['batch_size']) batch_size=settings['batch_size'])
if not (nlp.path / 'similarity').exists(): if not (nlp.path / 'similarity').exists():
(nlp.path / 'similarity').mkdir() (nlp.path / 'similarity').mkdir()
print("Saving to", model_dir / 'similarity') print("Saving to", nlp.path / 'similarity')
weights = model.get_weights() weights = model.get_weights()
with (nlp.path / 'similarity' / 'model').open('wb') as file_: with (nlp.path / 'similarity' / 'model').open('wb') as file_:
pickle.dump(weights[1:], file_) pickle.dump(weights[1:], file_)
@ -68,8 +68,8 @@ def evaluate(model_dir, dev_loc):
return correct, total return correct, total
def demo(model_dir): def demo():
nlp = spacy.load('en', path=model_dir, nlp = spacy.load('en',
create_pipeline=create_similarity_pipeline) create_pipeline=create_similarity_pipeline)
doc1 = nlp(u'What were the best crime fiction books in 2016?') doc1 = nlp(u'What were the best crime fiction books in 2016?')
doc2 = nlp( doc2 = nlp(
@ -98,7 +98,6 @@ def read_snli(path):
@plac.annotations( @plac.annotations(
mode=("Mode to execute", "positional", None, str, ["train", "evaluate", "demo"]), mode=("Mode to execute", "positional", None, str, ["train", "evaluate", "demo"]),
model_dir=("Path to spaCy model directory", "positional", None, Path),
train_loc=("Path to training data", "positional", None, Path), train_loc=("Path to training data", "positional", None, Path),
dev_loc=("Path to development data", "positional", None, Path), dev_loc=("Path to development data", "positional", None, Path),
max_length=("Length to truncate sentences", "option", "L", int), max_length=("Length to truncate sentences", "option", "L", int),
@ -110,7 +109,7 @@ def read_snli(path):
tree_truncate=("Truncate sentences by tree distance", "flag", "T", bool), tree_truncate=("Truncate sentences by tree distance", "flag", "T", bool),
gru_encode=("Encode sentences with bidirectional GRU", "flag", "E", bool), gru_encode=("Encode sentences with bidirectional GRU", "flag", "E", bool),
) )
def main(mode, model_dir, train_loc, dev_loc, def main(mode, train_loc, dev_loc,
tree_truncate=False, tree_truncate=False,
gru_encode=False, gru_encode=False,
max_length=100, max_length=100,
@ -129,12 +128,12 @@ def main(mode, model_dir, train_loc, dev_loc,
'gru_encode': gru_encode 'gru_encode': gru_encode
} }
if mode == 'train': if mode == 'train':
train(model_dir, train_loc, dev_loc, shape, settings) train(train_loc, dev_loc, shape, settings)
elif mode == 'evaluate': elif mode == 'evaluate':
correct, total = evaluate(model_dir, dev_loc) correct, total = evaluate(dev_loc)
print(correct, '/', total, correct / total) print(correct, '/', total, correct / total)
else: else:
demo(model_dir) demo()
if __name__ == '__main__': if __name__ == '__main__':
plac.call(main) plac.call(main)