Fix example

This commit is contained in:
Matthew Honnibal 2016-10-20 03:42:34 +02:00
parent d17546681c
commit 5378949326

View File

@@ -2,6 +2,7 @@ import plac
import collections
import random
import pathlib
import cytoolz
import numpy
from keras.models import Sequential, model_from_json
@@ -46,7 +47,7 @@ class SentimentAnalyser(object):
def get_features(docs, max_length):
Xs = numpy.zeros(len(docs), max_length, dtype='int32')
Xs = numpy.zeros((len(list(docs)), max_length), dtype='int32')
for i, doc in enumerate(docs):
for j, token in enumerate(doc[:max_length]):
Xs[i, j] = token.rank if token.has_vector else 0
@@ -69,8 +70,8 @@ def compile_lstm(embeddings, shape, settings):
model = Sequential()
model.add(
Embedding(
embeddings.shape[1],
embeddings.shape[0],
embeddings.shape[1],
input_length=shape['max_length'],
trainable=False,
weights=[embeddings]
@@ -142,6 +143,9 @@ def main(model_dir, train_dir, dev_dir,
nr_hidden=64, max_length=100, # Shape
dropout=0.5, # General NN config
nb_epoch=5, batch_size=100, nr_examples=-1): # Training params
model_dir = pathlib.Path(model_dir)
train_dir = pathlib.Path(train_dir)
dev_dir = pathlib.Path(dev_dir)
if is_runtime:
dev_texts, dev_labels = read_data(dev_dir)
demonstrate_runtime(model_dir, dev_texts)
@@ -149,7 +153,7 @@ def main(model_dir, train_dir, dev_dir,
train_texts, train_labels = read_data(train_dir, limit=nr_examples)
dev_texts, dev_labels = read_data(dev_dir)
lstm = train(train_texts, train_labels, dev_texts, dev_labels,
{'nr_hidden': nr_hidden, 'max_length': max_length},
{'nr_hidden': nr_hidden, 'max_length': max_length, 'nr_class': 1},
{'dropout': 0.5},
{},
nb_epoch=nb_epoch, batch_size=batch_size)