💫 Make TextCategorizer default to a simpler, GPU-friendly model (#3038)

Currently the TextCategorizer defaults to a fairly complicated model, designed partly around the active learning requirements of Prodigy. The model's a bit slow, and not very GPU-friendly.

This patch implements a straightforward CNN model that still performs pretty well. The replacement model also makes it easy to use the LMAO pretraining, since most of the parameters are in the CNN.

The replacement model has a flag to specify whether labels are mutually exclusive, which defaults to True. This has been a common problem with the text classifier. We'll also now be able to support adding labels to pretrained models again.

Resolves #2934, #2756, #1798, #1748.
This commit is contained in:
Matthew Honnibal 2018-12-10 14:37:39 +01:00 committed by GitHub
parent b1c8731b4d
commit 375f0dc529
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 8 deletions

View File

@ -44,6 +44,7 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000):
# add label to text classifier
textcat.add_label("POSITIVE")
textcat.add_label("NEGATIVE")
# load the IMDB dataset
print("Loading IMDB data...")
@ -64,7 +65,7 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000):
for i in range(n_iter):
losses = {}
# batch up the examples using spaCy's minibatch
batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
batches = minibatch(train_data, size=compounding(4.0, 16.0, 1.001))
for batch in batches:
texts, annotations = zip(*batch)
nlp.update(texts, annotations, sgd=optimizer, drop=0.2, losses=losses)
@ -106,22 +107,24 @@ def load_data(limit=0, split=0.8):
random.shuffle(train_data)
train_data = train_data[-limit:]
texts, labels = zip(*train_data)
cats = [{"POSITIVE": bool(y)} for y in labels]
cats = [{"POSITIVE": bool(y), "NEGATIVE": not bool(y)} for y in labels]
split = int(len(train_data) * split)
return (texts[:split], cats[:split]), (texts[split:], cats[split:])
def evaluate(tokenizer, textcat, texts, cats):
docs = (tokenizer(text) for text in texts)
tp = 1e-8 # True positives
tp = 0.0 # True positives
fp = 1e-8 # False positives
fn = 1e-8 # False negatives
tn = 1e-8 # True negatives
tn = 0.0 # True negatives
for i, doc in enumerate(textcat.pipe(docs)):
gold = cats[i]
for label, score in doc.cats.items():
if label not in gold:
continue
if label == "NEGATIVE":
continue
if score >= 0.5 and gold[label] >= 0.5:
tp += 1.0
elif score >= 0.5 and gold[label] < 0.5:

View File

@ -5,7 +5,7 @@ import numpy
from thinc.v2v import Model, Maxout, Softmax, Affine, ReLu
from thinc.i2v import HashEmbed, StaticVectors
from thinc.t2t import ExtractWindow, ParametricAttention
from thinc.t2v import Pooling, sum_pool
from thinc.t2v import Pooling, sum_pool, mean_pool
from thinc.misc import Residual
from thinc.misc import LayerNorm as LN
from thinc.misc import FeatureExtracter
@ -575,6 +575,32 @@ def build_text_classifier(nr_class, width=64, **cfg):
return model
def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=True, **cfg):
"""
Build a simple CNN text classifier, given a token-to-vector model as inputs.
If exclusive_classes=True, a softmax non-linearity is applied, so that the
outputs sum to 1. If exclusive_classes=False, a logistic non-linearity
is applied instead, so that outputs are in the range [0, 1].
"""
with Model.define_operators({">>": chain}):
if exclusive_classes:
output_layer = Softmax(nr_class, tok2vec.nO)
else:
output_layer = (
zero_init(Affine(nr_class, tok2vec.nO))
>> logistic
)
model = (
tok2vec
>> flatten_add_lengths
>> Pooling(mean_pool)
>> output_layer
)
model.tok2vec = chain(tok2vec, flatten)
model.nO = nr_class
return model
@layerize
def flatten(seqs, drop=0.0):
ops = Model.ops

View File

@ -30,6 +30,7 @@ from .tokens.span import Span
from .attrs import POS, ID
from .parts_of_speech import X
from ._ml import Tok2Vec, build_text_classifier, build_tagger_model
from ._ml import build_simple_cnn_text_classifier
from ._ml import link_vectors_to_models, zero_init, flatten
from ._ml import create_default_optimizer
from .errors import Errors, TempErrors
@ -1043,15 +1044,20 @@ class TextCategorizer(Pipe):
@classmethod
def Model(cls, nr_class, **cfg):
return build_text_classifier(nr_class, **cfg)
embed_size = util.env_opt("embed_size", 2000)
if "token_vector_width" in cfg:
token_vector_width = cfg["token_vector_width"]
else:
token_vector_width = util.env_opt("token_vector_width", 96)
tok2vec = Tok2Vec(token_vector_width, embed_size, **cfg)
return build_simple_cnn_text_classifier(tok2vec, nr_class, **cfg)
@property
def tok2vec(self):
if self.model in (None, True, False):
return None
else:
return chain(self.model.tok2vec, flatten)
return self.model.tok2vec
def __init__(self, vocab, model=True, **cfg):
self.vocab = vocab