mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-11 16:52:21 +03:00
Fix deep learning example
This commit is contained in:
parent
5378949326
commit
213027a1a1
|
@ -7,6 +7,7 @@ import cytoolz
|
||||||
import numpy
|
import numpy
|
||||||
from keras.models import Sequential, model_from_json
|
from keras.models import Sequential, model_from_json
|
||||||
from keras.layers import LSTM, Dense, Embedding, Dropout, Bidirectional
|
from keras.layers import LSTM, Dense, Embedding, Dropout, Bidirectional
|
||||||
|
from keras.optimizers import Adam
|
||||||
import cPickle as pickle
|
import cPickle as pickle
|
||||||
|
|
||||||
import spacy
|
import spacy
|
||||||
|
@ -61,7 +62,7 @@ def train(train_texts, train_labels, dev_texts, dev_labels,
|
||||||
model = compile_lstm(embeddings, lstm_shape, lstm_settings)
|
model = compile_lstm(embeddings, lstm_shape, lstm_settings)
|
||||||
train_X = get_features(nlp.pipe(train_texts), lstm_shape['max_length'])
|
train_X = get_features(nlp.pipe(train_texts), lstm_shape['max_length'])
|
||||||
dev_X = get_features(nlp.pipe(dev_texts), lstm_shape['max_length'])
|
dev_X = get_features(nlp.pipe(dev_texts), lstm_shape['max_length'])
|
||||||
model.fit(train_X, train_labels, dev_X, dev_labels,
|
model.fit(train_X, train_labels, validation_data=(dev_X, dev_labels),
|
||||||
nb_epoch=nb_epoch, batch_size=batch_size)
|
nb_epoch=nb_epoch, batch_size=batch_size)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@ -80,6 +81,8 @@ def compile_lstm(embeddings, shape, settings):
|
||||||
model.add(Bidirectional(LSTM(shape['nr_hidden'])))
|
model.add(Bidirectional(LSTM(shape['nr_hidden'])))
|
||||||
model.add(Dropout(settings['dropout']))
|
model.add(Dropout(settings['dropout']))
|
||||||
model.add(Dense(shape['nr_class'], activation='sigmoid'))
|
model.add(Dense(shape['nr_class'], activation='sigmoid'))
|
||||||
|
model.compile(optimizer=Adam(lr=settings['lr']), loss='binary_crossentropy',
|
||||||
|
metrics=['accuracy'])
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@ -134,6 +137,7 @@ def read_data(data_dir, limit=0):
|
||||||
nr_hidden=("Number of hidden units", "option", "H", int),
|
nr_hidden=("Number of hidden units", "option", "H", int),
|
||||||
max_length=("Maximum sentence length", "option", "L", int),
|
max_length=("Maximum sentence length", "option", "L", int),
|
||||||
dropout=("Dropout", "option", "d", float),
|
dropout=("Dropout", "option", "d", float),
|
||||||
|
learn_rate=("Learn rate", "option", "e", float),
|
||||||
nb_epoch=("Number of training epochs", "option", "i", int),
|
nb_epoch=("Number of training epochs", "option", "i", int),
|
||||||
batch_size=("Size of minibatches for training LSTM", "option", "b", int),
|
batch_size=("Size of minibatches for training LSTM", "option", "b", int),
|
||||||
nr_examples=("Limit to N examples", "option", "n", int)
|
nr_examples=("Limit to N examples", "option", "n", int)
|
||||||
|
@ -141,7 +145,7 @@ def read_data(data_dir, limit=0):
|
||||||
def main(model_dir, train_dir, dev_dir,
|
def main(model_dir, train_dir, dev_dir,
|
||||||
is_runtime=False,
|
is_runtime=False,
|
||||||
nr_hidden=64, max_length=100, # Shape
|
nr_hidden=64, max_length=100, # Shape
|
||||||
dropout=0.5, # General NN config
|
dropout=0.5, learn_rate=0.001, # General NN config
|
||||||
nb_epoch=5, batch_size=100, nr_examples=-1): # Training params
|
nb_epoch=5, batch_size=100, nr_examples=-1): # Training params
|
||||||
model_dir = pathlib.Path(model_dir)
|
model_dir = pathlib.Path(model_dir)
|
||||||
train_dir = pathlib.Path(train_dir)
|
train_dir = pathlib.Path(train_dir)
|
||||||
|
@ -152,9 +156,11 @@ def main(model_dir, train_dir, dev_dir,
|
||||||
else:
|
else:
|
||||||
train_texts, train_labels = read_data(train_dir, limit=nr_examples)
|
train_texts, train_labels = read_data(train_dir, limit=nr_examples)
|
||||||
dev_texts, dev_labels = read_data(dev_dir)
|
dev_texts, dev_labels = read_data(dev_dir)
|
||||||
|
train_labels = numpy.asarray(train_labels, dtype='int32')
|
||||||
|
dev_labels = numpy.asarray(dev_labels, dtype='int32')
|
||||||
lstm = train(train_texts, train_labels, dev_texts, dev_labels,
|
lstm = train(train_texts, train_labels, dev_texts, dev_labels,
|
||||||
{'nr_hidden': nr_hidden, 'max_length': max_length, 'nr_class': 1},
|
{'nr_hidden': nr_hidden, 'max_length': max_length, 'nr_class': 1},
|
||||||
{'dropout': 0.5},
|
{'dropout': 0.5, 'lr': learn_rate},
|
||||||
{},
|
{},
|
||||||
nb_epoch=nb_epoch, batch_size=batch_size)
|
nb_epoch=nb_epoch, batch_size=batch_size)
|
||||||
weights = lstm.get_weights()
|
weights = lstm.get_weights()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user