diff --git a/examples/training/train_textcat.py b/examples/training/train_textcat.py index 033cc50a9..eefae111f 100644 --- a/examples/training/train_textcat.py +++ b/examples/training/train_textcat.py @@ -28,8 +28,8 @@ def train_textcat(tokenizer, textcat, batch_sizes = compounding(4., 128., 1.001) for i in range(n_iter): losses = {} - for batch in minibatch(tqdm.tqdm(train_data, leave=False), - size=batch_sizes): + progress = tqdm.tqdm(train_data, leave=False) # Fresh progress bar each epoch; train_data itself is not rebound + for batch in minibatch(progress, size=batch_sizes): docs, golds = zip(*batch) textcat.update((docs, None), golds, sgd=optimizer, drop=0.2, losses=losses) @@ -70,7 +70,7 @@ def load_data(): texts, labels = zip(*train_data) cats = [(['POSITIVE'] if y else []) for y in labels] - + split = int(len(train_data) * 0.8) train_texts = texts[:split] @@ -104,7 +104,6 @@ def main(model_loc=None): doc = nlp(u'This movie sucked!') print(doc.cats) - if __name__ == '__main__': plac.call(main)