Fix conflict

Matthew Honnibal 2016-11-13 08:52:20 -06:00
parent 12a7b05360
commit 389e8b700e
3 changed files with 16 additions and 23 deletions

View File

@@ -22,22 +22,13 @@ def train(model_dir, train_loc, dev_loc, shape, settings):
     print("Compiling network")
     model = build_model(get_embeddings(nlp.vocab), shape, settings)
     print("Processing texts...")
-    train_X1 = get_word_ids(list(nlp.pipe(train_texts1, n_threads=10, batch_size=10000)),
-                     max_length=shape[0],
-                     tree_truncate=settings['tree_truncate'])
-    train_X2 = get_word_ids(list(nlp.pipe(train_texts2, n_threads=10, batch_size=10000)),
-                     max_length=shape[0],
-                     tree_truncate=settings['tree_truncate'])
-    dev_X1 = get_word_ids(list(nlp.pipe(dev_texts1, n_threads=10, batch_size=10000)),
-                     max_length=shape[0],
-                     tree_truncate=settings['tree_truncate'])
-    dev_X2 = get_word_ids(list(nlp.pipe(dev_texts2, n_threads=10, batch_size=10000)),
-                     max_length=shape[0],
-                     tree_truncate=settings['tree_truncate'])
-    print(train_X1.shape, train_X2.shape)
-    print(dev_X1.shape, dev_X2.shape)
-    print(train_labels.shape, dev_labels.shape)
+    Xs = []
+    for texts in (train_texts1, train_texts2, dev_texts1, dev_texts2):
+        Xs.append(get_word_ids(list(nlp.pipe(texts, n_threads=20, batch_size=20000)),
+                         max_length=shape[0],
+                         rnn_encode=settings['gru_encode'],
+                         tree_truncate=settings['tree_truncate']))
+    train_X1, train_X2, dev_X1, dev_X2 = Xs
     print(settings)
     model.fit(
         [train_X1, train_X2],
@@ -103,7 +94,7 @@ def read_snli(path):
     dropout=("Dropout level", "option", "d", float),
     learn_rate=("Learning rate", "option", "e", float),
     batch_size=("Batch size for neural network training", "option", "b", float),
-    nr_epoch=("Number of training epochs", "option", "i", float),
+    nr_epoch=("Number of training epochs", "option", "i", int),
     tree_truncate=("Truncate sentences by tree distance", "flag", "T", bool),
     gru_encode=("Encode sentences with bidirectional GRU", "flag", "E", bool),
 )
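Side note on the nr_epoch annotation change in the hunk above: plac coerces each command-line option with the annotated type, and an epoch count is consumed as an integer downstream (for example by range() or Keras's nb_epoch), so float was the wrong converter. A minimal sketch of the pattern, with a hypothetical main() rather than this example's real entry point:

    import plac

    @plac.annotations(
        nr_epoch=("Number of training epochs", "option", "i", int),
        learn_rate=("Learning rate", "option", "e", float),
    )
    def main(nr_epoch=5, learn_rate=0.001):
        # range() raises TypeError on a float, which is why the annotation
        # above uses int for the epoch count but float for the learning rate.
        for epoch in range(nr_epoch):
            print("epoch", epoch, "lr", learn_rate)

    if __name__ == '__main__':
        plac.call(main)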

View File

@@ -107,8 +107,6 @@ class _Attention(object):
         def _outer(AB):
             att_ji = K.batch_dot(AB[1], K.permute_dimensions(AB[0], (0, 2, 1)))
             return K.permute_dimensions(att_ji,(0, 2, 1))
         return merge(
             [self.model(sent1), self.model(sent2)],
             mode=_outer,
@@ -153,6 +151,7 @@ class _Comparison(object):
     def __call__(self, sent, align, **kwargs):
         result = self.model(merge([sent, align], mode='concat')) # Shape: (i, n)
         result = _GlobalSumPooling1D()(result, mask=self.words)
+        result = BatchNormalization()(result)
         return result
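On the BatchNormalization() line added above: the comparison network's output is summed over the sentence by _GlobalSumPooling1D, and normalizing that pooled vector before it reaches the downstream classifier tends to stabilize training. A minimal sketch of where such a layer sits, assuming the Keras 1.x functional API this example uses; the input width and class count below are placeholders, not the example's real shapes:

    from keras.layers import Input, Dense
    from keras.layers.normalization import BatchNormalization
    from keras.models import Model

    pooled = Input(shape=(200,))                     # stand-in for the summed comparison vector
    normed = BatchNormalization()(pooled)            # normalize the pooled activations batch-wise
    scores = Dense(3, activation='softmax')(normed)  # e.g. a 3-way entailment classifier
    model = Model(input=pooled, output=scores)       # Keras 1.x keyword names
    model.compile(optimizer='adam', loss='categorical_crossentropy')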

View File

@@ -40,16 +40,19 @@ def get_embeddings(vocab):
     return vectors
 
-def get_word_ids(docs, tree_truncate=False, max_length=100):
+def get_word_ids(docs, rnn_encode=False, tree_truncate=False, max_length=100):
     Xs = numpy.zeros((len(docs), max_length), dtype='int32')
     for i, doc in enumerate(docs):
-        j = 0
-        queue = [sent.root for sent in doc.sents]
+        if tree_truncate:
+            queue = [sent.root for sent in doc.sents]
+        else:
+            queue = list(doc)
         words = []
         while len(words) <= max_length and queue:
             word = queue.pop(0)
-            if word.has_vector and not word.is_punct and not word.is_space:
+            if rnn_encode or (word.has_vector and not word.is_punct and not word.is_space):
                 words.append(word)
-            queue.extend(list(word.lefts))
-            queue.extend(list(word.rights))
+            if tree_truncate:
+                queue.extend(list(word.lefts))
+                queue.extend(list(word.rights))
         words.sort()
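To make the new get_word_ids keyword arguments concrete, a hedged usage sketch (not code from the repository): it assumes spaCy 1.x with the 'en' model installed and that the helper is importable from this example's spacy_hook module.

    import spacy

    from spacy_hook import get_word_ids  # assumed module path for this example's helper

    nlp = spacy.load('en')
    texts = ["A man is playing a guitar.", "Someone is making music."]
    docs = list(nlp.pipe(texts, n_threads=20, batch_size=20000))

    # Default: tokens are taken in document order; punctuation, whitespace and
    # words without a vector are skipped.
    X_default = get_word_ids(docs, max_length=50)

    # rnn_encode=True keeps every token, which suits the GRU/RNN encoding path.
    X_rnn = get_word_ids(docs, rnn_encode=True, max_length=50)

    # tree_truncate=True walks breadth-first from each sentence root, so the
    # tokens closest to the root survive when a doc is longer than max_length.
    X_tree = get_word_ids(docs, tree_truncate=True, max_length=50)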