spaCy/examples/keras_parikh_entailment/__main__.py

128 lines
4.4 KiB
Python
Raw Normal View History

2016-11-01 03:51:54 +03:00
from __future__ import division, unicode_literals, print_function
import spacy
import plac
from pathlib import Path
import ujson as json
import numpy
from keras.utils.np_utils import to_categorical
2016-11-01 03:51:54 +03:00
from spacy_hook import get_embeddings, get_word_ids
from spacy_hook import create_similarity_pipeline
2016-11-12 02:13:12 +03:00
from keras_decomposable_attention import build_model
2016-11-01 03:51:54 +03:00
2016-11-01 03:51:54 +03:00
def train(model_dir, train_loc, dev_loc, shape, settings):
    """Train the entailment model on SNLI-style data.

    Reads the training and development files with ``read_snli`` and loads
    the 'en' spaCy model. It builds the Keras model from the vocab's
    embeddings and converts every sentence into word-id features with
    ``get_word_ids``. It then calls ``model.fit`` using the dev set as
    validation data.

    Args:
        model_dir: Model directory path (currently unused, see note below).
        train_loc: Path of the training JSON-lines file.
        dev_loc: Path of the development JSON-lines file.
        shape: Tuple whose first item is the maximum sentence length; the
            whole tuple is forwarded to ``build_model``.
        settings: Dict providing at least 'gru_encode', 'tree_truncate',
            'nr_epoch' and 'batch_size'; also forwarded to ``build_model``.
    """
    # NOTE(review): model_dir is accepted but never used here, so the trained
    # weights are not written anywhere, while evaluate()/demo() load from
    # model_dir. Confirm where the trained model is meant to be saved.
    train_texts1, train_texts2, train_labels = read_snli(train_loc)
    dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc)

    print("Loading spaCy")
    nlp = spacy.load('en')

    print("Compiling network")
    model = build_model(get_embeddings(nlp.vocab), shape, settings)
    print("Processing texts...")

    def featurize(sentences):
        # Parse the sentences in large batches, then map them to id arrays.
        parsed = list(nlp.pipe(sentences, n_threads=20, batch_size=20000))
        return get_word_ids(parsed,
                            max_length=shape[0],
                            rnn_encode=settings['gru_encode'],
                            tree_truncate=settings['tree_truncate'])

    corpora = (train_texts1, train_texts2, dev_texts1, dev_texts2)
    train_X1, train_X2, dev_X1, dev_X2 = [featurize(c) for c in corpora]
    print(settings)

    model.fit(
        [train_X1, train_X2],
        train_labels,
        validation_data=([dev_X1, dev_X2], dev_labels),
        nb_epoch=settings['nr_epoch'],
        batch_size=settings['batch_size'])
def evaluate(model_dir, dev_loc):
    """Score the similarity pipeline stored in ``model_dir`` on dev data.

    Args:
        model_dir: Path handed to ``spacy.load`` as the model directory.
        dev_loc: Path of an SNLI-style JSON-lines file, read with
            ``read_snli``.

    Returns:
        A ``(correct, total)`` tuple. ``correct`` counts the sentence pairs
        whose highest-scoring predicted class matches the gold class, and
        ``total`` is the number of pairs evaluated.
    """
    dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc)
    nlp = spacy.load('en', path=model_dir,
                     tagger=False, parser=False, entity=False, matcher=False,
                     create_pipeline=create_similarity_pipeline)
    total = 0
    correct = 0
    for text1, text2, label in zip(dev_texts1, dev_texts2, dev_labels):
        doc1 = nlp(text1)
        doc2 = nlp(text2)
        sim = doc1.similarity(doc2)
        # read_snli returns one-hot label rows, so compare class indices
        # rather than comparing a bool against an array.
        # NOTE(review): assumes the similarity hook returns one score per
        # class, in LABELS order -- confirm against spacy_hook.
        if numpy.argmax(sim) == numpy.argmax(label):
            correct += 1
        total += 1
    return correct, total
def demo(model_dir):
    """Print similarity scores for two hard-coded review snippets.

    Loads spaCy from ``model_dir``, with the tagger, parser, entity
    recogniser and matcher disabled and ``create_similarity_pipeline``
    building the pipeline. It prints the document-level similarity, then
    the similarities between the two sentences of the first text and
    between each of those sentences and the second text.
    """
    nlp = spacy.load('en', path=model_dir,
                     tagger=False, parser=False, entity=False, matcher=False,
                     create_pipeline=create_similarity_pipeline)
    first = nlp(u'Worst fries ever! Greasy and horrible...')
    second = nlp(u'The milkshakes are good. The fries are bad.')
    print('doc1.similarity(doc2)', first.similarity(second))

    # NOTE(review): this unpacking expects the first text to split into
    # exactly two sentences (ValueError otherwise), and it relies on sentence
    # boundaries being available even though parser=False -- confirm against
    # create_similarity_pipeline.
    opening, closing = first.sents
    comparisons = (
        ('sent1a.similarity(sent1b)', opening, closing),
        ('sent1a.similarity(doc2)', opening, second),
        ('sent1b.similarity(doc2)', closing, second),
    )
    for caption, left, right in comparisons:
        print(caption, left.similarity(right))
# Gold-label string -> integer class id used for the one-hot label matrix.
LABELS = {'entailment': 0, 'contradiction': 1, 'neutral': 2}


def read_snli(path):
    """Read an SNLI-style JSON-lines file.

    Each non-blank line must be a JSON object with 'sentence1', 'sentence2'
    and 'gold_label' keys. Examples whose gold label is '-' are skipped.

    Args:
        path: ``pathlib.Path`` (or path string) of the file to read.

    Returns:
        ``(texts1, texts2, labels)``. ``texts1`` and ``texts2`` are lists of
        sentence strings. ``labels`` is a one-hot matrix with one row per
        kept example and one column per entry in ``LABELS``.

    Raises:
        KeyError: if a gold label other than '-' is missing from ``LABELS``.
    """
    texts1 = []
    texts2 = []
    labels = []
    # Decode explicitly as UTF-8 instead of relying on the platform locale.
    with Path(path).open(encoding='utf8') as file_:
        for line in file_:
            if not line.strip():
                continue
            eg = json.loads(line)
            label = eg['gold_label']
            if label == '-':
                continue
            texts1.append(eg['sentence1'])
            texts2.append(eg['sentence2'])
            labels.append(LABELS[label])
    # Pass the class count explicitly so the matrix always has one column per
    # label, even when a file does not contain every class (or is empty).
    label_ids = numpy.asarray(labels, dtype='int32')
    return texts1, texts2, to_categorical(label_ids, len(LABELS))
2016-11-01 03:51:54 +03:00
@plac.annotations(
    mode=("Mode to execute", "positional", None, str, ["train", "evaluate", "demo"]),
    model_dir=("Path to spaCy model directory", "positional", None, Path),
    train_loc=("Path to training data", "positional", None, Path),
    dev_loc=("Path to development data", "positional", None, Path),
    max_length=("Length to truncate sentences", "option", "L", int),
    nr_hidden=("Number of hidden units", "option", "H", int),
    dropout=("Dropout level", "option", "d", float),
    learn_rate=("Learning rate", "option", "e", float),
    # int, not float: the value is passed straight to model.fit as batch size.
    batch_size=("Batch size for neural network training", "option", "b", int),
    nr_epoch=("Number of training epochs", "option", "i", int),
    tree_truncate=("Truncate sentences by tree distance", "flag", "T", bool),
    gru_encode=("Encode sentences with bidirectional GRU", "flag", "E", bool),
)
def main(mode, model_dir, train_loc, dev_loc,
         tree_truncate=False,
         gru_encode=False,
         max_length=100,
         nr_hidden=100,
         dropout=0.2,
         learn_rate=0.001,
         batch_size=100,
         nr_epoch=5):
    """Command-line entry point: run 'train', 'evaluate' or 'demo'.

    For 'train', the hyper-parameters are packed into ``shape`` and
    ``settings`` and passed to ``train``. For 'evaluate', the
    ``(correct, total)`` result is printed along with the accuracy. Any
    other mode runs ``demo``.
    """
    # NOTE(review): the trailing 3 presumably equals the number of output
    # classes (len(LABELS)) -- confirm against build_model.
    shape = (max_length, nr_hidden, 3)
    settings = {
        'lr': learn_rate,
        'dropout': dropout,
        'batch_size': batch_size,
        'nr_epoch': nr_epoch,
        'tree_truncate': tree_truncate,
        'gru_encode': gru_encode,
    }
    if mode == 'train':
        train(model_dir, train_loc, dev_loc, shape, settings)
    elif mode == 'evaluate':
        correct, total = evaluate(model_dir, dev_loc)
        print('Correct:', correct, 'of', total)
        if total:
            print('Accuracy:', correct / total)
    else:
        demo(model_dir)


if __name__ == '__main__':
    plac.call(main)