From ca996fc01a874894e20865ba0e2256a0d8b334f7 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Sat, 12 Nov 2016 01:15:01 +1100
Subject: [PATCH] Add BiRNN for entailment

Hastily add bidirectional RNN to entailment example
---
 .../keras_decomposable_attention.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/examples/keras_parikh_entailment/keras_decomposable_attention.py b/examples/keras_parikh_entailment/keras_decomposable_attention.py
index 21ecda447..a69d3ab7e 100644
--- a/examples/keras_parikh_entailment/keras_decomposable_attention.py
+++ b/examples/keras_parikh_entailment/keras_decomposable_attention.py
@@ -12,7 +12,7 @@ from keras.optimizers import Adam
 from keras.layers.normalization import BatchNormalization
 
 
-def build_model(vectors, shape, settings):
+def build_model(vectors, shape, settings, use_rnn_encoding=False):
     '''Compile the model.'''
     max_length, nr_hidden, nr_class = shape
     # Declare inputs.
@@ -21,6 +21,8 @@ def build_model(vectors, shape, settings):
 
     # Construct operations, which we'll chain together.
     embed = _StaticEmbedding(vectors, max_length, nr_hidden)
+    if use_rnn_encoding:
+        encode = _BiRNNEncoding(max_length, nr_hidden)
     attend = _Attention(max_length, nr_hidden)
     align = _SoftAlignment(max_length, nr_hidden)
     compare = _Comparison(max_length, nr_hidden)
@@ -29,6 +31,10 @@ def build_model(vectors, shape, settings):
     # Declare the model as a computational graph.
     sent1 = embed(ids1) # Shape: (i, n)
     sent2 = embed(ids2) # Shape: (j, n)
+
+    if use_rnn_encoding:
+        sent1 = encode(sent1)
+        sent2 = encode(sent2)
 
     attention = attend(sent1, sent2) # Shape: (i, j)
 
@@ -72,7 +78,14 @@ class _StaticEmbedding(object):
     def __call__(self, sentence):
         return self.project(self.embed(sentence))
 
+
+class _BiRNNEncoding(object):
+    def __init__(self, max_length, nr_out):
+        self.model = Sequential()
+        self.model.add(Bidirectional(LSTM(nr_out, return_sequences=True), input_shape=(max_length, nr_out), merge_mode='sum'))
+    def __call__(self, sentence):
+        return self.model(sentence)
 
 class _Attention(object):
     def __init__(self, max_length, nr_hidden, dropout=0.0, L2=1e-4, activation='relu'):
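
Usage sketch (not part of the patch): a minimal illustration of how the new
use_rnn_encoding flag might be exercised against build_model as changed above.
The module import path, the vectors array, the shape tuple, and the settings
dict values are hypothetical placeholders, not taken from the patch.

    # Hypothetical driver script for the entailment example.
    import numpy
    from keras_decomposable_attention import build_model  # assumed import path

    vectors = numpy.zeros((10000, 300), dtype='float32')  # placeholder embedding table
    shape = (50, 200, 3)  # (max_length, nr_hidden, nr_class) -- illustrative values
    settings = {'lr': 0.001, 'dropout': 0.2}  # assumed contents of the settings dict

    # Default path: static embeddings feed the attention layers directly.
    model = build_model(vectors, shape, settings)

    # New path: each sentence is re-encoded by the bidirectional LSTM first.
    rnn_model = build_model(vectors, shape, settings, use_rnn_encoding=True)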