From 563f46f026054a73289bca64d7d6cbc2cca07150 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Thu, 5 Oct 2017 18:43:02 -0500
Subject: [PATCH] Fix multi-label support for text classification

The TextCategorizer class is supposed to support multi-label text
classification, and allow training data to contain missing values. For
this to work, the gradient of the loss should be 0 when labels are
missing. Instead, there was no way to actually denote "missing" in the
GoldParse class, and so the TextCategorizer class treated the label set
within gold.cats as complete.

To fix this, we change GoldParse.cats to be a dict instead of a list.
The GoldParse.cats dict should map labels to floats, with 1. denoting
'present' and 0. denoting 'absent'. Gradients are zeroed for categories
absent from the gold.cats dict. A nice bonus is that you can also set
values between 0 and 1 for partial membership, or other numeric values
if you're using a text classification model with an appropriate loss
function.

Unfortunately this is a breaking change, although the functionality was
only recently introduced and hasn't been properly documented yet. I've
updated the example script accordingly.
---
 examples/training/train_textcat.py | 17 +++++++++--------
 spacy/gold.pyx                     | 13 ++++++++-----
 spacy/pipeline.pyx                 |  9 +++++++--
 3 files changed, 24 insertions(+), 15 deletions(-)

diff --git a/examples/training/train_textcat.py b/examples/training/train_textcat.py
index 6018827a4..4d07ed26a 100644
--- a/examples/training/train_textcat.py
+++ b/examples/training/train_textcat.py
@@ -21,7 +21,6 @@ import thinc.neural._classes.layernorm
 thinc.neural._classes.layernorm.set_compat_six_eight(False)
 
 
-
 def train_textcat(tokenizer, textcat,
         train_texts, train_cats, dev_texts, dev_cats,
         n_iter=20):
@@ -57,18 +56,20 @@ def evaluate(tokenizer, textcat, texts, cats):
     for i, doc in enumerate(textcat.pipe(docs)):
         gold = cats[i]
         for label, score in doc.cats.items():
-            if score >= 0.5 and label in gold:
+            if label not in gold:
+                continue
+            if score >= 0.5 and gold[label] >= 0.5:
                 tp += 1.
-            elif score >= 0.5 and label not in gold:
+            elif score >= 0.5 and gold[label] < 0.5:
                 fp += 1.
-            elif score < 0.5 and label not in gold:
+            elif score < 0.5 and gold[label] < 0.5:
                 tn += 1
-            if score < 0.5 and label in gold:
+            elif score < 0.5 and gold[label] >= 0.5:
                 fn += 1
     precis = tp / (tp + fp)
     recall = tp / (tp + fn)
     fscore = 2 * (precis * recall) / (precis + recall)
-    return {'textcat_p': precis, 'textcat_r': recall, 'textcat_f': fscore} 
+    return {'textcat_p': precis, 'textcat_r': recall, 'textcat_f': fscore}
 
 
 def load_data(limit=0):
@@ -80,7 +81,7 @@ def load_data(limit=0):
     train_data = train_data[-limit:]
 
     texts, labels = zip(*train_data)
-    cats = [(['POSITIVE'] if y else []) for y in labels]
+    cats = [{'POSITIVE': bool(y)} for y in labels]
 
     split = int(len(train_data) * 0.8)
 
@@ -97,7 +98,7 @@ def main(model_loc=None):
     textcat = TextCategorizer(tokenizer.vocab, labels=['POSITIVE'])
 
     print("Load IMDB data")
-    (train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=1000)
+    (train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=2000)
 
     print("Itn.\tLoss\tP\tR\tF")
     progress = '{i:d} {loss:.3f} {textcat_p:.3f} {textcat_r:.3f} {textcat_f:.3f}'
diff --git a/spacy/gold.pyx b/spacy/gold.pyx
index fc8d6622b..2512c179f 100644
--- a/spacy/gold.pyx
+++ b/spacy/gold.pyx
@@ -387,7 +387,7 @@ cdef class GoldParse:
 
     def __init__(self, doc, annot_tuples=None, words=None, tags=None, heads=None,
                  deps=None, entities=None, make_projective=False,
-                 cats=tuple()):
+                 cats=None):
         """Create a GoldParse.
 
         doc (Doc): The document the annotations refer to.
@@ -398,12 +398,15 @@ cdef class GoldParse:
         entities (iterable): A sequence of named entity annotations, either as
             BILUO tag strings, or as `(start_char, end_char, label)` tuples,
             representing the entity positions.
-        cats (iterable): A sequence of labels for text classification. Each
-            label may be a string or an int, or a `(start_char, end_char, label)`
+        cats (dict): Labels for text classification. Each key in the dictionary
+            may be a string or an int, or a `(start_char, end_char, label)`
             tuple, indicating that the label is applied to only part of the
             document (usually a sentence). Unlike entity annotations, label
             annotations can overlap, i.e. a single word can be covered by
-            multiple labelled spans.
+            multiple labelled spans. The TextCategorizer component expects
+            true examples of a label to have the value 1.0, and negative examples
+            of a label to have the value 0.0. Labels not in the dictionary are
+            treated as missing -- the gradient for those labels will be zero.
         RETURNS (GoldParse): The newly constructed object.
""" if words is None: @@ -434,7 +437,7 @@ cdef class GoldParse: self.c.sent_start = self.mem.alloc(len(doc), sizeof(int)) self.c.ner = self.mem.alloc(len(doc), sizeof(Transition)) - self.cats = list(cats) + self.cats = {} if cats is None else dict(cats) self.words = [None] * len(doc) self.tags = [None] * len(doc) self.heads = [None] * len(doc) diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index 8d935335c..c39976630 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -551,7 +551,6 @@ class NeuralLabeller(NeuralTagger): label = self.make_label(i, words, tags, heads, deps, ents) if label is not None and label not in self.labels: self.labels[label] = len(self.labels) - print(len(self.labels)) if self.model is True: token_vector_width = util.env_opt('token_vector_width') self.model = chain( @@ -720,11 +719,17 @@ class TextCategorizer(BaseThincComponent): def get_loss(self, docs, golds, scores): truths = numpy.zeros((len(golds), len(self.labels)), dtype='f') + not_missing = numpy.ones((len(golds), len(self.labels)), dtype='f') for i, gold in enumerate(golds): for j, label in enumerate(self.labels): - truths[i, j] = label in gold.cats + if label in gold.cats: + truths[i, j] = gold.cats[label] + else: + not_missing[i, j] = 0. truths = self.model.ops.asarray(truths) + not_missing = self.model.ops.asarray(not_missing) d_scores = (scores-truths) / scores.shape[0] + d_scores *= not_missing mean_square_error = ((scores-truths)**2).sum(axis=1).mean() return mean_square_error, d_scores