diff --git a/examples/keras_parikh_entailment/__main__.py b/examples/keras_parikh_entailment/__main__.py index 927120f3c..5c3132bab 100644 --- a/examples/keras_parikh_entailment/__main__.py +++ b/examples/keras_parikh_entailment/__main__.py @@ -18,7 +18,7 @@ except ImportError: import pickle -def train(model_dir, train_loc, dev_loc, shape, settings): +def train(train_loc, dev_loc, shape, settings): train_texts1, train_texts2, train_labels = read_snli(train_loc) dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc) @@ -44,7 +44,7 @@ def train(model_dir, train_loc, dev_loc, shape, settings): batch_size=settings['batch_size']) if not (nlp.path / 'similarity').exists(): (nlp.path / 'similarity').mkdir() - print("Saving to", model_dir / 'similarity') + print("Saving to", nlp.path / 'similarity') weights = model.get_weights() with (nlp.path / 'similarity' / 'model').open('wb') as file_: pickle.dump(weights[1:], file_) @@ -68,8 +68,8 @@ def evaluate(model_dir, dev_loc): return correct, total -def demo(model_dir): - nlp = spacy.load('en', path=model_dir, +def demo(): + nlp = spacy.load('en', create_pipeline=create_similarity_pipeline) doc1 = nlp(u'What were the best crime fiction books in 2016?') doc2 = nlp( @@ -98,7 +98,6 @@ def read_snli(path): @plac.annotations( mode=("Mode to execute", "positional", None, str, ["train", "evaluate", "demo"]), - model_dir=("Path to spaCy model directory", "positional", None, Path), train_loc=("Path to training data", "positional", None, Path), dev_loc=("Path to development data", "positional", None, Path), max_length=("Length to truncate sentences", "option", "L", int), @@ -110,7 +109,7 @@ def read_snli(path): tree_truncate=("Truncate sentences by tree distance", "flag", "T", bool), gru_encode=("Encode sentences with bidirectional GRU", "flag", "E", bool), ) -def main(mode, model_dir, train_loc, dev_loc, +def main(mode, train_loc, dev_loc, tree_truncate=False, gru_encode=False, max_length=100, @@ -129,12 +128,12 @@ def main(mode, model_dir, train_loc, dev_loc, 'gru_encode': gru_encode } if mode == 'train': - train(model_dir, train_loc, dev_loc, shape, settings) + train(train_loc, dev_loc, shape, settings) elif mode == 'evaluate': - correct, total = evaluate(model_dir, dev_loc) + correct, total = evaluate(dev_loc) print(correct, '/', total, correct / total) else: - demo(model_dir) + demo() if __name__ == '__main__': plac.call(main)