diff --git a/examples/training/train_textcat.py b/examples/training/train_textcat.py index eefae111f..7eb356100 100644 --- a/examples/training/train_textcat.py +++ b/examples/training/train_textcat.py @@ -1,3 +1,7 @@ +'''Train a multi-label convolutional neural network text classifier, +using the spacy.pipeline.TextCategorizer component. The model is then added +to spacy.pipeline, and predictions are available at `doc.cats`. +''' from __future__ import unicode_literals import plac import random @@ -31,7 +35,7 @@ def train_textcat(tokenizer, textcat, train_data = tqdm.tqdm(train_data, leave=False) # Progress bar for batch in minibatch(train_data, size=batch_sizes): docs, golds = zip(*batch) - textcat.update((docs, None), golds, sgd=optimizer, drop=0.2, + textcat.update(docs, golds, sgd=optimizer, drop=0.2, losses=losses) with textcat.model.use_params(optimizer.averages): scores = evaluate(tokenizer, textcat, dev_texts, dev_cats)