diff --git a/examples/paddle/sentiment_bilstm/config.py b/examples/paddle/sentiment_bilstm/config.py
index cdee7cdf9..cde30cf61 100644
--- a/examples/paddle/sentiment_bilstm/config.py
+++ b/examples/paddle/sentiment_bilstm/config.py
@@ -1,4 +1,5 @@
 from paddle.trainer.PyDataProvider2 import *
+from itertools import izip
 
 
 def get_features(doc):
@@ -8,6 +9,19 @@ def get_features(doc):
         dtype='int32')
 
 
+def read_data(data_dir):
+    """Yield (text, label) pairs from the pos/ and neg/ folders of data_dir.
+
+    data_dir must support the pathlib-style ``/`` operator and ``iterdir()``.
+    Files under ``pos`` are labelled 1 and files under ``neg`` are labelled 0.
+    """
+    for subdir, label in (('pos', 1), ('neg', 0)):
+        for filename in (data_dir / subdir).iterdir():
+            with filename.open() as file_:
+                text = file_.read()
+            yield text, label
+
+
 def on_init(settings, lang_name, **kwargs):
     print("Loading spaCy")
     nlp = spacy.load('en', entity=False)
@@ -28,8 +42,9 @@ def on_init(settings, lang_name, **kwargs):
 @provider(init_hook=on_init)
 def process(settings, data_dir): # settings is not used currently.
-    texts, labels = read_data(data_dir)
-    for doc, label in zip(nlp.pipe(train_texts, batch_size=5000, n_threads=3),
-                          labels):
+    # read_data() is a generator of (text, label) pairs; transpose it into
+    # two parallel sequences instead of unpacking the generator itself.
+    texts, labels = zip(*read_data(data_dir))
+    for doc, label in izip(nlp.pipe(texts, batch_size=5000, n_threads=3), labels):
         for sent in doc.sents:
             ids = get_features(sent)
             # give data to paddle.