Fix train loop for train_textcat example

Author: Matthew Honnibal, 2019-03-22 16:10:11 +01:00 (committed by GitHub)
parent 9cee3f702a
commit 4c5f265884


@@ -73,10 +73,12 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
             textcat.model.tok2vec.from_bytes(file_.read())
     print("Training the model...")
     print("{:^5}\t{:^5}\t{:^5}\t{:^5}".format("LOSS", "P", "R", "F"))
+    batch_sizes = compounding(4.0, 32.0, 1.001)
     for i in range(n_iter):
         losses = {}
         # batch up the examples using spaCy's minibatch
-        batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
+        random.shuffle(train_data)
+        batches = minibatch(train_data, size=batch_sizes)
         for batch in batches:
             texts, annotations = zip(*batch)
             nlp.update(texts, annotations, sgd=optimizer, drop=0.2, losses=losses)
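The change fixes two problems in the loop. compounding() returns an infinite generator, so constructing it inside the epoch loop restarted the batch-size schedule at 4.0 on every pass; hoisting it out lets the sizes keep growing across epochs. The loop also never reshuffled train_data, so every epoch saw the batches in the same order. A minimal sketch of the generator behavior, using spacy.util.compounding (which the example script already imports):

from spacy.util import compounding

# Old behavior: a fresh generator per epoch restarts the schedule at 4.0.
for i in range(2):
    sizes = compounding(4.0, 32.0, 1.001)
    print(next(sizes))  # 4.0 both times

# New behavior: one generator shared across epochs keeps compounding
# (each value is the previous one * 1.001, capped at 32.0).
sizes = compounding(4.0, 32.0, 1.001)
for i in range(2):
    print(next(sizes))  # 4.0, then 4.004

Since minibatch() draws the next value from the size generator for each batch, batch sizes now grow smoothly over the whole run instead of resetting at the start of each epoch.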