2016-11-01 03:51:54 +03:00
|
|
|
from __future__ import division, unicode_literals, print_function
|
|
|
|
import spacy
|
|
|
|
|
|
|
|
import plac
|
|
|
|
from pathlib import Path
|
2016-11-12 20:43:37 +03:00
|
|
|
import ujson as json
|
|
|
|
import numpy
|
|
|
|
from keras.utils.np_utils import to_categorical
|
2016-11-01 03:51:54 +03:00
|
|
|
|
|
|
|
from spacy_hook import get_embeddings, get_word_ids
|
|
|
|
from spacy_hook import create_similarity_pipeline
|
|
|
|
|
2016-11-12 02:13:12 +03:00
|
|
|
from keras_decomposable_attention import build_model
|
2016-11-01 03:51:54 +03:00
|
|
|
|
2017-01-31 22:27:13 +03:00
|
|
|
try:
|
|
|
|
import cPickle as pickle
|
|
|
|
except ImportError:
|
|
|
|
import pickle
|
|
|
|
|
2016-11-12 20:43:37 +03:00
|
|
|
|
2017-02-18 12:38:22 +03:00
|
|
|
def train(train_loc, dev_loc, shape, settings):
    """Train the similarity model on SNLI data and save it inside the spaCy model directory.

    Args:
        train_loc (Path): SNLI JSONL file used for training (parsed by read_snli).
        dev_loc (Path): SNLI JSONL file used as Keras validation data.
        shape (tuple): passed to build_model; shape[0] is also used as the
            maximum number of word ids per sentence.
        settings (dict): hyper-parameters passed to build_model. This function
            itself reads 'gru_encode', 'tree_truncate', 'nr_epoch' and 'batch_size'.

    Side effects:
        Writes ``<nlp.path>/similarity/model`` (pickled list of weight arrays,
        without the first one) and ``<nlp.path>/similarity/config.json``
        (the JSON returned by ``model.to_json()``).

    Raises:
        ValueError: if the loaded spaCy 'en' model has no data path to save into.
    """
    train_texts1, train_texts2, train_labels = read_snli(train_loc)
    dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc)

    print("Loading spaCy")
    nlp = spacy.load('en')
    # Explicit check rather than `assert`, so it is not stripped under `python -O`.
    if nlp.path is None:
        raise ValueError("spaCy model 'en' has no data path to save the similarity model into")

    print("Compiling network")
    model = build_model(get_embeddings(nlp.vocab), shape, settings)

    print("Processing texts...")
    Xs = []
    for texts in (train_texts1, train_texts2, dev_texts1, dev_texts2):
        Xs.append(get_word_ids(list(nlp.pipe(texts, n_threads=20, batch_size=20000)),
                               max_length=shape[0],
                               rnn_encode=settings['gru_encode'],
                               tree_truncate=settings['tree_truncate']))
    train_X1, train_X2, dev_X1, dev_X2 = Xs
    print(settings)
    model.fit(
        [train_X1, train_X2],
        train_labels,
        validation_data=([dev_X1, dev_X2], dev_labels),
        nb_epoch=settings['nr_epoch'],
        batch_size=settings['batch_size'])

    out_dir = nlp.path / 'similarity'
    if not out_dir.exists():
        out_dir.mkdir()
    print("Saving to", out_dir)
    weights = model.get_weights()
    # NOTE(review): the first weight array is deliberately not saved. It is
    # presumably the embedding table built from get_embeddings(nlp.vocab), which
    # can be rebuilt at load time. Confirm against build_model / spacy_hook.
    with (out_dir / 'model').open('wb') as file_:
        pickle.dump(weights[1:], file_)
    config = model.to_json()
    # A binary-mode file only accepts bytes on Python 3, and to_json() returns
    # text there, so encode it. Already-bytes output (Python 2 str) is written as-is.
    if not isinstance(config, bytes):
        config = config.encode('utf8')
    with (out_dir / 'config.json').open('wb') as file_:
        file_.write(config)
|
2016-11-01 03:51:54 +03:00
|
|
|
|
|
|
|
|
2017-04-05 13:50:47 +03:00
|
|
|
def evaluate(dev_loc):
    """Count how many SNLI pairs the similarity pipeline labels correctly.

    Both sentences of each pair are parsed by a spaCy 'en' pipeline built with
    create_similarity_pipeline. A pair is correct when the argmax of
    ``doc1.similarity(doc2)`` equals the argmax of its one-hot gold label.

    Args:
        dev_loc (Path): SNLI JSONL file, parsed by read_snli.

    Returns:
        tuple: (correct, total), both floats.
    """
    premises, hypotheses, gold_labels = read_snli(dev_loc)
    nlp = spacy.load('en', create_pipeline=create_similarity_pipeline)
    n_correct, n_total = 0., 0.
    for premise, hypothesis, gold in zip(premises, hypotheses, gold_labels):
        premise_doc = nlp(premise)
        hypothesis_doc = nlp(hypothesis)
        scores = premise_doc.similarity(hypothesis_doc)
        if scores.argmax() == gold.argmax():
            n_correct += 1
        n_total += 1
    return n_correct, n_total
|
|
|
|
|
|
|
|
|
2017-02-18 12:38:22 +03:00
|
|
|
def demo():
    """Print two example questions and the similarity the trained pipeline assigns them."""
    nlp = spacy.load('en', create_pipeline=create_similarity_pipeline)
    question = u'What were the best crime fiction books in 2016?'
    paraphrase = (
        u'What should I read that was published last year? I like crime stories.')
    doc1, doc2 = nlp(question), nlp(paraphrase)
    for doc in (doc1, doc2):
        print(doc)
    print("Similarity", doc1.similarity(doc2))
|
2016-11-01 03:51:54 +03:00
|
|
|
|
|
|
|
|
|
|
|
LABELS = {'entailment': 0, 'contradiction': 1, 'neutral': 2}


def read_snli(path):
    """Load sentence pairs and one-hot gold labels from an SNLI JSONL file.

    Each non-blank line must be a JSON object with 'sentence1', 'sentence2'
    and 'gold_label' keys. Lines whose gold label is '-' are skipped
    (presumably pairs without an agreed label).

    Args:
        path (Path): file to read, decoded as UTF-8.

    Returns:
        tuple: (texts1, texts2, labels). texts1 and texts2 are lists of
        sentences. labels is the one-hot encoding of the kept examples'
        LABELS indices, with len(LABELS) columns.

    Raises:
        KeyError: if a gold label is neither '-' nor a key of LABELS.
    """
    texts1 = []
    texts2 = []
    labels = []
    # Decode explicitly: the platform default encoding is not guaranteed to be UTF-8.
    with path.open(encoding='utf8') as file_:
        for line in file_:
            # Skip blank lines (e.g. a trailing newline), which json.loads rejects.
            if not line.strip():
                continue
            eg = json.loads(line)
            label = eg['gold_label']
            if label == '-':
                continue
            texts1.append(eg['sentence1'])
            texts2.append(eg['sentence2'])
            labels.append(LABELS[label])
    # Pin the class count so every split gets the same number of columns, even
    # when a split does not contain every label.
    return texts1, texts2, to_categorical(numpy.asarray(labels, dtype='int32'),
                                          len(LABELS))
|
2016-11-01 03:51:54 +03:00
|
|
|
|
|
|
|
|
|
|
|
@plac.annotations(
    mode=("Mode to execute", "positional", None, str, ["train", "evaluate", "demo"]),
    train_loc=("Path to training data", "positional", None, Path),
    dev_loc=("Path to development data", "positional", None, Path),
    max_length=("Length to truncate sentences", "option", "L", int),
    nr_hidden=("Number of hidden units", "option", "H", int),
    dropout=("Dropout level", "option", "d", float),
    learn_rate=("Learning rate", "option", "e", float),
    batch_size=("Batch size for neural network training", "option", "b", int),
    nr_epoch=("Number of training epochs", "option", "i", int),
    tree_truncate=("Truncate sentences by tree distance", "flag", "T", bool),
    gru_encode=("Encode sentences with bidirectional GRU", "flag", "E", bool),
)
def main(mode, train_loc, dev_loc,
         tree_truncate=False,
         gru_encode=False,
         max_length=100,
         nr_hidden=100,
         dropout=0.2,
         learn_rate=0.001,
         batch_size=100,
         nr_epoch=5):
    """Command-line entry point: run 'train', 'evaluate' or (any other mode) 'demo'.

    Builds the shape tuple and settings dict from the CLI options and
    dispatches to train(), evaluate() or demo(). In 'evaluate' mode it prints
    ``correct / total accuracy``, or a message if no labelled examples were
    found, instead of dividing by zero.
    """
    # NOTE(review): the third entry is presumably the output class count used by
    # build_model; it is kept equal to the number of SNLI LABELS (3).
    shape = (max_length, nr_hidden, len(LABELS))
    settings = {
        'lr': learn_rate,
        'dropout': dropout,
        'batch_size': batch_size,
        'nr_epoch': nr_epoch,
        'tree_truncate': tree_truncate,
        'gru_encode': gru_encode
    }
    if mode == 'train':
        train(train_loc, dev_loc, shape, settings)
    elif mode == 'evaluate':
        correct, total = evaluate(dev_loc)
        # Guard against an empty or fully '-'-labelled dev file.
        if total:
            print(correct, '/', total, correct / total)
        else:
            print("No labelled examples found in", dev_loc)
    else:
        demo()


if __name__ == '__main__':
    plac.call(main)
|