Clarify train textcat example

This commit is contained in:
Matthew Honnibal 2017-07-29 21:59:27 +02:00
parent 27abc56e98
commit c16ef0a85c

View File

@@ -28,8 +28,8 @@ def train_textcat(tokenizer, textcat,
batch_sizes = compounding(4., 128., 1.001) batch_sizes = compounding(4., 128., 1.001)
for i in range(n_iter): for i in range(n_iter):
losses = {} losses = {}
for batch in minibatch(tqdm.tqdm(train_data, leave=False), train_data = tqdm.tqdm(train_data, leave=False) # Progress bar
size=batch_sizes): for batch in minibatch(train_data, size=batch_sizes):
docs, golds = zip(*batch) docs, golds = zip(*batch)
textcat.update((docs, None), golds, sgd=optimizer, drop=0.2, textcat.update((docs, None), golds, sgd=optimizer, drop=0.2,
losses=losses) losses=losses)
@@ -70,7 +70,7 @@ def load_data():
texts, labels = zip(*train_data) texts, labels = zip(*train_data)
cats = [(['POSITIVE'] if y else []) for y in labels] cats = [(['POSITIVE'] if y else []) for y in labels]
split = int(len(train_data) * 0.8) split = int(len(train_data) * 0.8)
train_texts = texts[:split] train_texts = texts[:split]
@@ -104,7 +104,6 @@ def main(model_loc=None):
doc = nlp(u'This movie sucked!') doc = nlp(u'This movie sucked!')
print(doc.cats) print(doc.cats)
if __name__ == '__main__': if __name__ == '__main__':
plac.call(main) plac.call(main)