Mirror of https://github.com/explosion/spaCy.git (synced 2025-07-11 16:52:21 +03:00)
Fix example

parent d17546681c
commit 5378949326
@@ -2,6 +2,7 @@ import plac
 import collections
 import random
+import pathlib
 import cytoolz
 import numpy
 from keras.models import Sequential, model_from_json
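The new pathlib import backs the pathlib.Path(...) conversions added to main() further down this diff; without it, the first statement of the updated main() would raise a NameError.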
@@ -46,7 +47,7 @@ class SentimentAnalyser(object):


 def get_features(docs, max_length):
-    Xs = numpy.zeros(len(docs), max_length, dtype='int32')
+    Xs = numpy.zeros((len(list(docs)), max_length), dtype='int32')
     for i, doc in enumerate(docs):
         for j, token in enumerate(doc[:max_length]):
             Xs[i, j] = token.rank if token.has_vector else 0
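Why the old line fails: numpy.zeros takes the array shape as a single tuple, so numpy.zeros(len(docs), max_length, dtype='int32') pushes max_length into the dtype slot and the call raises a TypeError. Wrapping docs in list() guards against docs being a generator (for example from nlp.pipe), which has no len(). A minimal sketch with a made-up batch of docs:

    import numpy

    # Hypothetical stand-in for a batch of parsed docs; only its length matters here.
    docs = [['this', 'film', 'was', 'great'], ['terrible']]
    max_length = 5

    # The shape goes in as one tuple; two positional ints would not work.
    Xs = numpy.zeros((len(list(docs)), max_length), dtype='int32')
    print(Xs.shape)  # (2, 5)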
@@ -69,8 +70,8 @@ def compile_lstm(embeddings, shape, settings):
     model = Sequential()
     model.add(
         Embedding(
-            embeddings.shape[1],
             embeddings.shape[0],
+            embeddings.shape[1],
             input_length=shape['max_length'],
             trainable=False,
             weights=[embeddings]
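The swap matters because Keras' Embedding layer takes (input_dim, output_dim) in that order: first the number of rows in the vector table (the vocabulary size, embeddings.shape[0]), then the width of each vector (embeddings.shape[1]). A rough sketch, with a random vector table standing in for spaCy's vocab vectors and mirroring the older Keras API used in this example:

    import numpy
    from keras.models import Sequential
    from keras.layers import Embedding

    # Hypothetical vector table: 1000 "words", 32 dimensions each.
    embeddings = numpy.random.normal(size=(1000, 32)).astype('float32')

    model = Sequential()
    model.add(
        Embedding(
            embeddings.shape[0],   # input_dim: vocabulary size (rows)
            embeddings.shape[1],   # output_dim: vector width (columns)
            input_length=100,
            trainable=False,
            weights=[embeddings]
        )
    )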
@@ -142,6 +143,9 @@ def main(model_dir, train_dir, dev_dir,
          nr_hidden=64, max_length=100,  # Shape
          dropout=0.5,  # General NN config
          nb_epoch=5, batch_size=100, nr_examples=-1):  # Training params
+    model_dir = pathlib.Path(model_dir)
+    train_dir = pathlib.Path(train_dir)
+    dev_dir = pathlib.Path(dev_dir)
     if is_runtime:
         dev_texts, dev_labels = read_data(dev_dir)
         demonstrate_runtime(model_dir, dev_texts)
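These conversions are needed because the command-line arguments presumably arrive in main() as plain strings, while the rest of the script treats the directories as path objects. A small sketch with a made-up directory name:

    import pathlib

    # Hypothetical directory; the point is only the string-to-Path conversion.
    model_dir = pathlib.Path('/tmp/sentiment_model')

    print(model_dir / 'model.json')  # joining with the / operator
    print(model_dir.exists())        # a query a plain string would not support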
@@ -149,7 +153,7 @@ def main(model_dir, train_dir, dev_dir,
         train_texts, train_labels = read_data(train_dir, limit=nr_examples)
         dev_texts, dev_labels = read_data(dev_dir)
         lstm = train(train_texts, train_labels, dev_texts, dev_labels,
-                     {'nr_hidden': nr_hidden, 'max_length': max_length},
+                     {'nr_hidden': nr_hidden, 'max_length': max_length, 'nr_class': 1},
                      {'dropout': 0.5},
                      {},
                      nb_epoch=nb_epoch, batch_size=batch_size)
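The added 'nr_class': 1 key fills out the shape settings that train() forwards to compile_lstm(), presumably because compile_lstm() reads shape['nr_class'] when it sizes the output layer. A hedged sketch of that kind of lookup (the Dense output layer below is an assumption, not something shown in this diff):

    from keras.models import Sequential
    from keras.layers import Dense

    # Hypothetical shape settings mirroring the call in main(); 'nr_class': 1
    # means a single sigmoid output for binary sentiment.
    shape = {'nr_hidden': 64, 'max_length': 100, 'nr_class': 1}

    model = Sequential()
    # Assumed consumer of the new key: an output layer sized by shape['nr_class'].
    model.add(Dense(shape['nr_class'], activation='sigmoid',
                    input_shape=(shape['nr_hidden'],)))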