Merge pull request #1391 from explosion/feature/multilabel-textcat

💫 Fix multi-label support for text classification
This commit is contained in:
Matthew Honnibal 2017-10-09 04:22:31 +02:00 committed by GitHub
commit e79fc41ff8
3 changed files with 24 additions and 15 deletions

View File

@ -21,7 +21,6 @@ import thinc.neural._classes.layernorm
thinc.neural._classes.layernorm.set_compat_six_eight(False) thinc.neural._classes.layernorm.set_compat_six_eight(False)
def train_textcat(tokenizer, textcat, def train_textcat(tokenizer, textcat,
train_texts, train_cats, dev_texts, dev_cats, train_texts, train_cats, dev_texts, dev_cats,
n_iter=20): n_iter=20):
@ -57,13 +56,15 @@ def evaluate(tokenizer, textcat, texts, cats):
for i, doc in enumerate(textcat.pipe(docs)): for i, doc in enumerate(textcat.pipe(docs)):
gold = cats[i] gold = cats[i]
for label, score in doc.cats.items(): for label, score in doc.cats.items():
if score >= 0.5 and label in gold: if label not in gold:
continue
if score >= 0.5 and gold[label] >= 0.5:
tp += 1. tp += 1.
elif score >= 0.5 and label not in gold: elif score >= 0.5 and gold[label] < 0.5:
fp += 1. fp += 1.
elif score < 0.5 and label not in gold: elif score < 0.5 and gold[label] < 0.5:
tn += 1 tn += 1
if score < 0.5 and label in gold: elif score < 0.5 and gold[label] >= 0.5:
fn += 1 fn += 1
precis = tp / (tp + fp) precis = tp / (tp + fp)
recall = tp / (tp + fn) recall = tp / (tp + fn)
@ -80,7 +81,7 @@ def load_data(limit=0):
train_data = train_data[-limit:] train_data = train_data[-limit:]
texts, labels = zip(*train_data) texts, labels = zip(*train_data)
cats = [(['POSITIVE'] if y else []) for y in labels] cats = [{'POSITIVE': bool(y)} for y in labels]
split = int(len(train_data) * 0.8) split = int(len(train_data) * 0.8)
@ -97,7 +98,7 @@ def main(model_loc=None):
textcat = TextCategorizer(tokenizer.vocab, labels=['POSITIVE']) textcat = TextCategorizer(tokenizer.vocab, labels=['POSITIVE'])
print("Load IMDB data") print("Load IMDB data")
(train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=1000) (train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=2000)
print("Itn.\tLoss\tP\tR\tF") print("Itn.\tLoss\tP\tR\tF")
progress = '{i:d} {loss:.3f} {textcat_p:.3f} {textcat_r:.3f} {textcat_f:.3f}' progress = '{i:d} {loss:.3f} {textcat_p:.3f} {textcat_r:.3f} {textcat_f:.3f}'

View File

@ -387,7 +387,7 @@ cdef class GoldParse:
def __init__(self, doc, annot_tuples=None, words=None, tags=None, heads=None, def __init__(self, doc, annot_tuples=None, words=None, tags=None, heads=None,
deps=None, entities=None, make_projective=False, deps=None, entities=None, make_projective=False,
cats=tuple()): cats=None):
"""Create a GoldParse. """Create a GoldParse.
doc (Doc): The document the annotations refer to. doc (Doc): The document the annotations refer to.
@ -398,12 +398,15 @@ cdef class GoldParse:
entities (iterable): A sequence of named entity annotations, either as entities (iterable): A sequence of named entity annotations, either as
BILUO tag strings, or as `(start_char, end_char, label)` tuples, BILUO tag strings, or as `(start_char, end_char, label)` tuples,
representing the entity positions. representing the entity positions.
cats (iterable): A sequence of labels for text classification. Each cats (dict): Labels for text classification. Each key in the dictionary
label may be a string or an int, or a `(start_char, end_char, label)` may be a string or an int, or a `(start_char, end_char, label)`
tuple, indicating that the label is applied to only part of the tuple, indicating that the label is applied to only part of the
document (usually a sentence). Unlike entity annotations, label document (usually a sentence). Unlike entity annotations, label
annotations can overlap, i.e. a single word can be covered by annotations can overlap, i.e. a single word can be covered by
multiple labelled spans. multiple labelled spans. The TextCategorizer component expects
true examples of a label to have the value 1.0, and negative examples
of a label to have the value 0.0. Labels not in the dictionary are
treated as missing -- the gradient for those labels will be zero.
RETURNS (GoldParse): The newly constructed object. RETURNS (GoldParse): The newly constructed object.
""" """
if words is None: if words is None:
@ -434,7 +437,7 @@ cdef class GoldParse:
self.c.sent_start = <int*>self.mem.alloc(len(doc), sizeof(int)) self.c.sent_start = <int*>self.mem.alloc(len(doc), sizeof(int))
self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition)) self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
self.cats = list(cats) self.cats = {} if cats is None else dict(cats)
self.words = [None] * len(doc) self.words = [None] * len(doc)
self.tags = [None] * len(doc) self.tags = [None] * len(doc)
self.heads = [None] * len(doc) self.heads = [None] * len(doc)

View File

@ -551,7 +551,6 @@ class NeuralLabeller(NeuralTagger):
label = self.make_label(i, words, tags, heads, deps, ents) label = self.make_label(i, words, tags, heads, deps, ents)
if label is not None and label not in self.labels: if label is not None and label not in self.labels:
self.labels[label] = len(self.labels) self.labels[label] = len(self.labels)
print(len(self.labels))
if self.model is True: if self.model is True:
token_vector_width = util.env_opt('token_vector_width') token_vector_width = util.env_opt('token_vector_width')
self.model = chain( self.model = chain(
@ -720,11 +719,17 @@ class TextCategorizer(BaseThincComponent):
def get_loss(self, docs, golds, scores): def get_loss(self, docs, golds, scores):
truths = numpy.zeros((len(golds), len(self.labels)), dtype='f') truths = numpy.zeros((len(golds), len(self.labels)), dtype='f')
not_missing = numpy.ones((len(golds), len(self.labels)), dtype='f')
for i, gold in enumerate(golds): for i, gold in enumerate(golds):
for j, label in enumerate(self.labels): for j, label in enumerate(self.labels):
truths[i, j] = label in gold.cats if label in gold.cats:
truths[i, j] = gold.cats[label]
else:
not_missing[i, j] = 0.
truths = self.model.ops.asarray(truths) truths = self.model.ops.asarray(truths)
not_missing = self.model.ops.asarray(not_missing)
d_scores = (scores-truths) / scores.shape[0] d_scores = (scores-truths) / scores.shape[0]
d_scores *= not_missing
mean_square_error = ((scores-truths)**2).sum(axis=1).mean() mean_square_error = ((scores-truths)**2).sum(axis=1).mean()
return mean_square_error, d_scores return mean_square_error, d_scores