Update config.py

This commit is contained in:
Matthew Honnibal 2016-11-20 03:45:51 +01:00
parent 409a18bd42
commit 001abe2b9d

View File

@ -1,4 +1,5 @@
from paddle.trainer.PyDataProvider2 import * from paddle.trainer.PyDataProvider2 import *
from itertools import izip
def get_features(doc): def get_features(doc):
@ -8,6 +9,14 @@ def get_features(doc):
dtype='int32') dtype='int32')
def read_data(data_dir):
for subdir, label in (('pos', 1), ('neg', 0)):
for filename in (data_dir / subdir).iterdir():
with filename.open() as file_:
text = file_.read()
yield text, label
def on_init(settings, lang_name, **kwargs): def on_init(settings, lang_name, **kwargs):
print("Loading spaCy") print("Loading spaCy")
nlp = spacy.load('en', entity=False) nlp = spacy.load('en', entity=False)
@ -28,8 +37,7 @@ def on_init(settings, lang_name, **kwargs):
@provider(init_hook=on_init) @provider(init_hook=on_init)
def process(settings, data_dir): # settings is not used currently. def process(settings, data_dir): # settings is not used currently.
texts, labels = read_data(data_dir) texts, labels = read_data(data_dir)
for doc, label in zip(nlp.pipe(train_texts, batch_size=5000, n_threads=3), for doc, label in izip(nlp.pipe(texts, batch_size=5000, n_threads=3), labels):
labels):
for sent in doc.sents: for sent in doc.sents:
ids = get_features(sent) ids = get_features(sent)
# give data to paddle. # give data to paddle.