mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Fix train loop for train_textcat example
This commit is contained in:
parent
9cee3f702a
commit
4c5f265884
|
@ -73,10 +73,12 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
|
|||
textcat.model.tok2vec.from_bytes(file_.read())
|
||||
print("Training the model...")
|
||||
print("{:^5}\t{:^5}\t{:^5}\t{:^5}".format("LOSS", "P", "R", "F"))
|
||||
batch_sizes = compounding(4.0, 32.0, 1.001)
|
||||
for i in range(n_iter):
|
||||
losses = {}
|
||||
# batch up the examples using spaCy's minibatch
|
||||
batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
|
||||
random.shuffle(train_data)
|
||||
batches = minibatch(train_data, size=batch_sizes)
|
||||
for batch in batches:
|
||||
texts, annotations = zip(*batch)
|
||||
nlp.update(texts, annotations, sgd=optimizer, drop=0.2, losses=losses)
|
||||
|
|
Loading…
Reference in New Issue
Block a user