mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Fix example
This commit is contained in:
parent
d17546681c
commit
5378949326
|
@ -2,6 +2,7 @@ import plac
|
|||
import collections
|
||||
import random
|
||||
|
||||
import pathlib
|
||||
import cytoolz
|
||||
import numpy
|
||||
from keras.models import Sequential, model_from_json
|
||||
|
@ -46,7 +47,7 @@ class SentimentAnalyser(object):
|
|||
|
||||
|
||||
def get_features(docs, max_length):
|
||||
Xs = numpy.zeros(len(docs), max_length, dtype='int32')
|
||||
Xs = numpy.zeros((len(list(docs)), max_length), dtype='int32')
|
||||
for i, doc in enumerate(docs):
|
||||
for j, token in enumerate(doc[:max_length]):
|
||||
Xs[i, j] = token.rank if token.has_vector else 0
|
||||
|
@ -69,8 +70,8 @@ def compile_lstm(embeddings, shape, settings):
|
|||
model = Sequential()
|
||||
model.add(
|
||||
Embedding(
|
||||
embeddings.shape[1],
|
||||
embeddings.shape[0],
|
||||
embeddings.shape[1],
|
||||
input_length=shape['max_length'],
|
||||
trainable=False,
|
||||
weights=[embeddings]
|
||||
|
@ -142,6 +143,9 @@ def main(model_dir, train_dir, dev_dir,
|
|||
nr_hidden=64, max_length=100, # Shape
|
||||
dropout=0.5, # General NN config
|
||||
nb_epoch=5, batch_size=100, nr_examples=-1): # Training params
|
||||
model_dir = pathlib.Path(model_dir)
|
||||
train_dir = pathlib.Path(train_dir)
|
||||
dev_dir = pathlib.Path(dev_dir)
|
||||
if is_runtime:
|
||||
dev_texts, dev_labels = read_data(dev_dir)
|
||||
demonstrate_runtime(model_dir, dev_texts)
|
||||
|
@ -149,7 +153,7 @@ def main(model_dir, train_dir, dev_dir,
|
|||
train_texts, train_labels = read_data(train_dir, limit=nr_examples)
|
||||
dev_texts, dev_labels = read_data(dev_dir)
|
||||
lstm = train(train_texts, train_labels, dev_texts, dev_labels,
|
||||
{'nr_hidden': nr_hidden, 'max_length': max_length},
|
||||
{'nr_hidden': nr_hidden, 'max_length': max_length, 'nr_class': 1},
|
||||
{'dropout': 0.5},
|
||||
{},
|
||||
nb_epoch=nb_epoch, batch_size=batch_size)
|
||||
|
|
Loading…
Reference in New Issue
Block a user