Merge pull request #955 from kumaranvpl/fix_keras_parikh_entailment_bugs

Fix keras_parikh_entailment example bugs
This commit is contained in:
Matthew Honnibal 2017-04-07 14:59:57 +02:00 committed by GitHub
commit a5538d93d0
3 changed files with 4 additions and 4 deletions

View File

@ -78,7 +78,7 @@ You can run the `keras_parikh_entailment/` directory as a script, which executes
[`keras_parikh_entailment/__main__.py`](__main__.py). The first thing you'll want to do is train the model:
```bash
python keras_parikh_entailment/ train <your_model_dir> <train_directory> <dev_directory>
python keras_parikh_entailment/ train <train_directory> <dev_directory>
```
Training takes about 300 epochs for full accuracy, and I haven't rerun the full

View File

@ -52,7 +52,7 @@ def train(train_loc, dev_loc, shape, settings):
file_.write(model.to_json())
def evaluate(model_dir, dev_loc):
def evaluate(dev_loc):
dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc)
nlp = spacy.load('en',
create_pipeline=create_similarity_pipeline)

View File

@ -80,10 +80,10 @@ def get_word_ids(docs, rnn_encode=False, tree_truncate=False, max_length=100, nr
return Xs
def create_similarity_pipeline(nlp):
def create_similarity_pipeline(nlp, max_length=100):
return [
nlp.tagger,
nlp.entity,
nlp.parser,
KerasSimilarityShim.load(nlp.path / 'similarity', nlp, max_length=10)
KerasSimilarityShim.load(nlp.path / 'similarity', nlp, max_length)
]