From bc8cda818c6d754a46bc6ffa2dbccd5a01181968 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 4 Nov 2018 00:17:09 +0000 Subject: [PATCH] Improve pretrain textcat example --- examples/training/pretrain_textcat.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/examples/training/pretrain_textcat.py b/examples/training/pretrain_textcat.py index c272145f5..8dc4e0786 100644 --- a/examples/training/pretrain_textcat.py +++ b/examples/training/pretrain_textcat.py @@ -47,8 +47,8 @@ def load_textcat_data(limit=0): train_data = train_data[-limit:] texts, labels = zip(*train_data) eval_texts, eval_labels = zip(*eval_data) - cats = [{'POSITIVE': bool(y)} for y in labels] - eval_cats = [{'POSITIVE': bool(y)} for y in eval_labels] + cats = [{'POSITIVE': bool(y), 'NEGATIVE': not bool(y)} for y in labels] + eval_cats = [{'POSITIVE': bool(y), 'NEGATIVE': not bool(y)} for y in eval_labels] return (texts, cats), (eval_texts, eval_cats) @@ -61,9 +61,9 @@ def prefer_gpu(): def build_textcat_model(tok2vec, nr_class, width): - from thinc.v2v import Model, Affine, Maxout + from thinc.v2v import Model, Softmax, Maxout from thinc.api import flatten_add_lengths, chain - from thinc.t2v import Pooling, sum_pool, max_pool + from thinc.t2v import Pooling, sum_pool, mean_pool, max_pool from thinc.misc import Residual, LayerNorm from spacy._ml import logistic, zero_init @@ -71,11 +71,8 @@ def build_textcat_model(tok2vec, nr_class, width): model = ( tok2vec >> flatten_add_lengths - >> Pooling(sum_pool, max_pool) - >> Residual(LayerNorm(Maxout(width*2, width*2, pieces=3))) - >> Residual(LayerNorm(Maxout(width*2, width*2, pieces=3))) - >> zero_init(Affine(nr_class, width*2, drop_factor=0.0)) - >> logistic + >> Pooling(mean_pool) + >> Softmax(nr_class, width) ) model.tok2vec = tok2vec return model @@ -92,9 +89,9 @@ def create_pipeline(width, embed_size, vectors_model): nlp = spacy.load(vectors_model) print("Start training") textcat = TextCategorizer(nlp.vocab, - labels=['POSITIVE'], + labels=['POSITIVE', 'NEGATIVE'], model=build_textcat_model( - Tok2Vec(width=width, embed_size=embed_size), 1, width)) + Tok2Vec(width=width, embed_size=embed_size), 2, width)) nlp.add_pipe(textcat) return nlp