diff --git a/examples/keras_parikh_entailment/__main__.py b/examples/keras_parikh_entailment/__main__.py
index 35553af43..ede6c9103 100644
--- a/examples/keras_parikh_entailment/__main__.py
+++ b/examples/keras_parikh_entailment/__main__.py
@@ -7,6 +7,7 @@ from pathlib import Path
 
 from spacy_hook import get_embeddings, get_word_ids
 from spacy_hook import create_similarity_pipeline
+from keras_decomposable_attention import build_model
 
 def train(model_dir, train_loc, dev_loc, shape, settings):
     print("Loading spaCy")
diff --git a/examples/keras_parikh_entailment/keras_decomposable_attention.py b/examples/keras_parikh_entailment/keras_decomposable_attention.py
index a69d3ab7e..ede435f42 100644
--- a/examples/keras_parikh_entailment/keras_decomposable_attention.py
+++ b/examples/keras_parikh_entailment/keras_decomposable_attention.py
@@ -101,10 +101,11 @@ class _Attention(object):
         self.model = TimeDistributed(self.model)
 
     def __call__(self, sent1, sent2):
-        def _outer((A, B)):
-            att_ji = T.batched_dot(B, A.dimshuffle((0, 2, 1)))
+        def _outer(AB):
+            att_ji = T.batched_dot(AB[1], AB[0].dimshuffle((0, 2, 1)))
             return att_ji.dimshuffle((0, 2, 1))
+
         return merge(
                 [self.model(sent1), self.model(sent2)],
                 mode=_outer,
@@ -117,7 +118,9 @@ class _SoftAlignment(object):
         self.nr_hidden = nr_hidden
 
     def __call__(self, sentence, attention, transpose=False):
-        def _normalize_attention((att, mat)):
+        def _normalize_attention(attmat):
+            att = attmat[0]
+            mat = attmat[1]
             if transpose:
                 att = att.dimshuffle((0, 2, 1))
             # 3d softmax
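
Note on the change: tuple parameter unpacking in function signatures was removed in Python 3 (PEP 3113), so definitions like "def _outer((A, B)):" are a SyntaxError there. The patch rewrites each such callback to take a single argument and index into it. The sketch below is only an illustration of that pattern, not the patched code itself: the arithmetic stands in for the Theano T.batched_dot / dimshuffle calls, so it runs standalone on any Python 3.

# Python 2 only -- SyntaxError under Python 3:
# def _outer((A, B)):
#     return combine(B, A)

# Python 3 compatible: accept one argument and index, as the patch does.
def _outer(AB):
    return AB[1] + AB[0]

# Equivalent and slightly more idiomatic: unpack inside the body.
def _outer_unpacked(AB):
    A, B = AB
    return B + A

print(_outer((1, 2)), _outer_unpacked((1, 2)))  # both print 3

The indexing style (attmat[0], attmat[1]) used in the patch and the in-body unpacking style are interchangeable here; Keras merge modes pass the list of inputs as a single argument either way.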