diff --git a/examples/training/train_textcat.py b/examples/training/train_textcat.py
index c678632cd..6745ddba6 100644
--- a/examples/training/train_textcat.py
+++ b/examples/training/train_textcat.py
@@ -73,10 +73,12 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
                 textcat.model.tok2vec.from_bytes(file_.read())
         print("Training the model...")
         print("{:^5}\t{:^5}\t{:^5}\t{:^5}".format("LOSS", "P", "R", "F"))
+        batch_sizes = compounding(4.0, 32.0, 1.001)
         for i in range(n_iter):
             losses = {}
             # batch up the examples using spaCy's minibatch
-            batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
+            random.shuffle(train_data)
+            batches = minibatch(train_data, size=batch_sizes)
             for batch in batches:
                 texts, annotations = zip(*batch)
                 nlp.update(texts, annotations, sgd=optimizer, drop=0.2, losses=losses)
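
Context for the change: spacy.util.compounding returns an infinite generator, so constructing it inside the epoch loop restarts the batch-size schedule at 4.0 on every epoch. Hoisting it above the loop lets the batch size keep compounding across epochs, and random.shuffle(train_data) re-orders the examples on each pass so batches differ between epochs. Below is a minimal sketch of what a compounding schedule does; the behavior is inferred from the spaCy v2 docs and the arguments used in the diff (start <= stop), not copied from the library source.

    def compounding(start, stop, compound):
        """Yield values that grow geometrically from `start` toward `stop`.

        Each value is the previous one multiplied by `compound`, capped at
        `stop`. Assumes start <= stop, as in the diff above.
        """
        curr = float(start)
        while True:
            yield min(curr, stop)  # never exceed the ceiling
            curr *= compound

    # Created once, so the schedule persists across epochs instead of
    # resetting to 4.0 at the top of every iteration of the training loop.
    batch_sizes = compounding(4.0, 32.0, 1.001)
    print([next(batch_sizes) for _ in range(3)])  # [4.0, 4.004, 4.008004]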