mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Merge pull request #955 from kumaranvpl/fix_keras_parikh_entailment_bugs
Fix keras_parikh_entailment example bugs
This commit is contained in:
commit
a5538d93d0
|
@ -78,7 +78,7 @@ You can run the `keras_parikh_entailment/` directory as a script, which executes
|
|||
[`keras_parikh_entailment/__main__.py`](__main__.py). The first thing you'll want to do is train the model:
|
||||
|
||||
```bash
|
||||
python keras_parikh_entailment/ train <your_model_dir> <train_directory> <dev_directory>
|
||||
python keras_parikh_entailment/ train <train_directory> <dev_directory>
|
||||
```
|
||||
|
||||
Training takes about 300 epochs for full accuracy, and I haven't rerun the full
|
||||
|
|
|
@ -52,7 +52,7 @@ def train(train_loc, dev_loc, shape, settings):
|
|||
file_.write(model.to_json())
|
||||
|
||||
|
||||
def evaluate(model_dir, dev_loc):
|
||||
def evaluate(dev_loc):
|
||||
dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc)
|
||||
nlp = spacy.load('en',
|
||||
create_pipeline=create_similarity_pipeline)
|
||||
|
|
|
@ -80,10 +80,10 @@ def get_word_ids(docs, rnn_encode=False, tree_truncate=False, max_length=100, nr
|
|||
return Xs
|
||||
|
||||
|
||||
def create_similarity_pipeline(nlp):
|
||||
def create_similarity_pipeline(nlp, max_length=100):
|
||||
return [
|
||||
nlp.tagger,
|
||||
nlp.entity,
|
||||
nlp.parser,
|
||||
KerasSimilarityShim.load(nlp.path / 'similarity', nlp, max_length=10)
|
||||
KerasSimilarityShim.load(nlp.path / 'similarity', nlp, max_length)
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue
Block a user