Update textcat exampe

This commit is contained in:
Matthew Honnibal 2017-10-04 14:55:30 +02:00
parent 774f5732bd
commit 79a94bc166

View File

@ -1,3 +1,7 @@
'''Train a multi-label convolutional neural network text classifier,
using the spacy.pipeline.TextCategorizer component. The model is then added
to spacy.pipeline, and predictions are available at `doc.cats`.
'''
from __future__ import unicode_literals
import plac
import random
@ -31,7 +35,7 @@ def train_textcat(tokenizer, textcat,
train_data = tqdm.tqdm(train_data, leave=False) # Progress bar
for batch in minibatch(train_data, size=batch_sizes):
docs, golds = zip(*batch)
textcat.update((docs, None), golds, sgd=optimizer, drop=0.2,
textcat.update(docs, golds, sgd=optimizer, drop=0.2,
losses=losses)
with textcat.model.use_params(optimizer.averages):
scores = evaluate(tokenizer, textcat, dev_texts, dev_cats)