mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Add BiRNN for entailment
Hastily add bidirectional RNN to entailment example
This commit is contained in:
parent
f123f92e0c
commit
ca996fc01a
|
@ -12,7 +12,7 @@ from keras.optimizers import Adam
|
|||
from keras.layers.normalization import BatchNormalization
|
||||
|
||||
|
||||
def build_model(vectors, shape, settings):
|
||||
def build_model(vectors, shape, settings, use_rnn_encoding=False):
|
||||
'''Compile the model.'''
|
||||
max_length, nr_hidden, nr_class = shape
|
||||
# Declare inputs.
|
||||
|
@ -21,6 +21,8 @@ def build_model(vectors, shape, settings):
|
|||
|
||||
# Construct operations, which we'll chain together.
|
||||
embed = _StaticEmbedding(vectors, max_length, nr_hidden)
|
||||
if use_rnn_encoding:
|
||||
encode = _BiLSTMEncode(max_length, nr_hidden)
|
||||
attend = _Attention(max_length, nr_hidden)
|
||||
align = _SoftAlignment(max_length, nr_hidden)
|
||||
compare = _Comparison(max_length, nr_hidden)
|
||||
|
@ -29,6 +31,10 @@ def build_model(vectors, shape, settings):
|
|||
# Declare the model as a computational graph.
|
||||
sent1 = embed(ids1) # Shape: (i, n)
|
||||
sent2 = embed(ids2) # Shape: (j, n)
|
||||
|
||||
if use_rnn_encoding:
|
||||
sent1 = encode(sent1)
|
||||
sent2 = encode(sent2)
|
||||
|
||||
attention = attend(sent1, sent2) # Shape: (i, j)
|
||||
|
||||
|
@ -72,7 +78,14 @@ class _StaticEmbedding(object):
|
|||
|
||||
def __call__(self, sentence):
|
||||
return self.project(self.embed(sentence))
|
||||
|
||||
class _BiRNNEncoding(object):
|
||||
def __init__(self, max_length, nr_out):
|
||||
self.model = Sequential()
|
||||
self.model.add(Bidirectional(LSTM(nr_out, input_length=max_length)))
|
||||
|
||||
def __call__(self, sentence):
|
||||
return self.model(sentence)
|
||||
|
||||
class _Attention(object):
|
||||
def __init__(self, max_length, nr_hidden, dropout=0.0, L2=1e-4, activation='relu'):
|
||||
|
|
Loading…
Reference in New Issue
Block a user