Merge pull request #955 from kumaranvpl/fix_keras_parikh_entailment_bugs

Fix keras_parikh_entailment example bugs
This commit is contained in:
Matthew Honnibal 2017-04-07 14:59:57 +02:00 committed by GitHub
commit a5538d93d0
3 changed files with 4 additions and 4 deletions

View File

@ -78,7 +78,7 @@ You can run the `keras_parikh_entailment/` directory as a script, which executes
[`keras_parikh_entailment/__main__.py`](__main__.py). The first thing you'll want to do is train the model: [`keras_parikh_entailment/__main__.py`](__main__.py). The first thing you'll want to do is train the model:
```bash ```bash
python keras_parikh_entailment/ train <your_model_dir> <train_directory> <dev_directory> python keras_parikh_entailment/ train <train_directory> <dev_directory>
``` ```
Training takes about 300 epochs for full accuracy, and I haven't rerun the full Training takes about 300 epochs for full accuracy, and I haven't rerun the full

View File

@ -52,7 +52,7 @@ def train(train_loc, dev_loc, shape, settings):
file_.write(model.to_json()) file_.write(model.to_json())
def evaluate(model_dir, dev_loc): def evaluate(dev_loc):
dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc) dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc)
nlp = spacy.load('en', nlp = spacy.load('en',
create_pipeline=create_similarity_pipeline) create_pipeline=create_similarity_pipeline)

View File

@ -80,10 +80,10 @@ def get_word_ids(docs, rnn_encode=False, tree_truncate=False, max_length=100, nr
return Xs return Xs
def create_similarity_pipeline(nlp): def create_similarity_pipeline(nlp, max_length=100):
return [ return [
nlp.tagger, nlp.tagger,
nlp.entity, nlp.entity,
nlp.parser, nlp.parser,
KerasSimilarityShim.load(nlp.path / 'similarity', nlp, max_length=10) KerasSimilarityShim.load(nlp.path / 'similarity', nlp, max_length)
] ]