Add BiRNN for entailment

Hastily add bidirectional RNN to entailment example

parent f123f92e0c
commit ca996fc01a
@@ -12,7 +12,7 @@ from keras.optimizers import Adam
 from keras.layers.normalization import BatchNormalization


-def build_model(vectors, shape, settings):
+def build_model(vectors, shape, settings, use_rnn_encoding=False):
     '''Compile the model.'''
     max_length, nr_hidden, nr_class = shape
     # Declare inputs.
@@ -21,6 +21,8 @@ def build_model(vectors, shape, settings):

     # Construct operations, which we'll chain together.
     embed = _StaticEmbedding(vectors, max_length, nr_hidden)
+    if use_rnn_encoding:
+        encode = _BiRNNEncoding(max_length, nr_hidden)
     attend = _Attention(max_length, nr_hidden)
     align = _SoftAlignment(max_length, nr_hidden)
     compare = _Comparison(max_length, nr_hidden)
@@ -30,6 +32,10 @@ def build_model(vectors, shape, settings):
     sent1 = embed(ids1) # Shape: (i, n)
     sent2 = embed(ids2) # Shape: (j, n)

+    if use_rnn_encoding:
+        sent1 = encode(sent1)
+        sent2 = encode(sent2)
+
     attention = attend(sent1, sent2) # Shape: (i, j)

     align1 = align(sent2, attention)
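Taken together, these hunks thread the optional encoder through build_model. A minimal call-site sketch, assuming the vectors, settings, max_length, nr_hidden, and nr_class values the surrounding example script already defines (none of this is part of the diff itself):

    # Hypothetical call site; all names come from the context lines above.
    shape = (max_length, nr_hidden, nr_class)
    model = build_model(vectors, shape, settings, use_rnn_encoding=True)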
@@ -73,6 +79,13 @@ class _StaticEmbedding(object):
     def __call__(self, sentence):
         return self.project(self.embed(sentence))

+class _BiRNNEncoding(object):
+    def __init__(self, max_length, nr_out):
+        self.model = Sequential()
+        self.model.add(Bidirectional(LSTM(nr_out, input_length=max_length)))
+
+    def __call__(self, sentence):
+        return self.model(sentence)

 class _Attention(object):
     def __init__(self, max_length, nr_hidden, dropout=0.0, L2=1e-4, activation='relu'):
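As committed, the encoder is unlikely to run as a drop-in replacement: the diff appears to add no imports for Bidirectional or LSTM, and an LSTM without return_sequences=True emits only the final hidden state, while the downstream _Attention and _SoftAlignment steps operate on one vector per token (see the # Shape: (i, n) comments above). Here is a self-contained sketch of a sequence-returning variant, assuming the Keras 1.x-style APIs used elsewhere in this example; the return_sequences flag, explicit input_shape, and halved unit count are assumptions added for illustration, not part of the commit:

    # Sketch only: return_sequences, input_shape, and nr_out // 2 are
    # illustrative assumptions; the committed code differs.
    from keras.models import Sequential
    from keras.layers import Bidirectional, LSTM

    class _BiRNNEncoding(object):
        def __init__(self, max_length, nr_out):
            self.model = Sequential()
            # return_sequences=True keeps one vector per token, so the
            # output shape stays (max_length, nr_out) for the attention.
            # Bidirectional concatenates both directions by default, so
            # each direction gets nr_out // 2 units (assumes nr_out even).
            self.model.add(Bidirectional(
                LSTM(nr_out // 2, return_sequences=True),
                input_shape=(max_length, nr_out)))

        def __call__(self, sentence):
            # Keras models are callable on tensors, so the encoder slots
            # straight into the graph that build_model() assembles.
            return self.model(sentence)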