Clarify train textcat example

This commit is contained in:
Matthew Honnibal 2017-07-29 21:59:27 +02:00
parent 27abc56e98
commit c16ef0a85c

View File

@ -28,8 +28,8 @@ def train_textcat(tokenizer, textcat,
batch_sizes = compounding(4., 128., 1.001)
for i in range(n_iter):
losses = {}
for batch in minibatch(tqdm.tqdm(train_data, leave=False),
size=batch_sizes):
train_data = tqdm.tqdm(train_data, leave=False) # Progress bar
for batch in minibatch(train_data, size=batch_sizes):
docs, golds = zip(*batch)
textcat.update((docs, None), golds, sgd=optimizer, drop=0.2,
losses=losses)
@ -105,6 +105,5 @@ def main(model_loc=None):
print(doc.cats)
if __name__ == '__main__':
plac.call(main)