mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Fix multi-label support for text classification
The TextCategorizer class is supposed to support multi-label text classification, and allow training data to contain missing values. For this to work, the gradient of the loss should be 0 when labels are missing. Instead, there was no way to actually denote "missing" in the GoldParse class, and so the TextCategorizer class treated the label set within gold.cats as complete. To fix this, we change GoldParse.cats to be a dict instead of a list. The GoldParse.cats dict should map to floats, with 1. denoting 'present' and 0. denoting 'absent'. Gradients are zeroed for categories absent from the gold.cats dict. A nice bonus is that you can also set values between 0 and 1 for partial membership. You can also set numeric values, if you're using a text classification model that uses an appropriate loss function. Unfortunately this is a breaking change; although the functionality was only recently introduced and hasn't been properly documented yet. I've updated the example script accordingly.
This commit is contained in:
parent
fb75eb52f1
commit
563f46f026
|
@ -21,7 +21,6 @@ import thinc.neural._classes.layernorm
|
||||||
thinc.neural._classes.layernorm.set_compat_six_eight(False)
|
thinc.neural._classes.layernorm.set_compat_six_eight(False)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def train_textcat(tokenizer, textcat,
|
def train_textcat(tokenizer, textcat,
|
||||||
train_texts, train_cats, dev_texts, dev_cats,
|
train_texts, train_cats, dev_texts, dev_cats,
|
||||||
n_iter=20):
|
n_iter=20):
|
||||||
|
@ -57,13 +56,15 @@ def evaluate(tokenizer, textcat, texts, cats):
|
||||||
for i, doc in enumerate(textcat.pipe(docs)):
|
for i, doc in enumerate(textcat.pipe(docs)):
|
||||||
gold = cats[i]
|
gold = cats[i]
|
||||||
for label, score in doc.cats.items():
|
for label, score in doc.cats.items():
|
||||||
if score >= 0.5 and label in gold:
|
if label not in gold:
|
||||||
|
continue
|
||||||
|
if score >= 0.5 and gold[label] >= 0.5:
|
||||||
tp += 1.
|
tp += 1.
|
||||||
elif score >= 0.5 and label not in gold:
|
elif score >= 0.5 and gold[label] < 0.5:
|
||||||
fp += 1.
|
fp += 1.
|
||||||
elif score < 0.5 and label not in gold:
|
elif score < 0.5 and gold[label] < 0.5:
|
||||||
tn += 1
|
tn += 1
|
||||||
if score < 0.5 and label in gold:
|
elif score < 0.5 and gold[label] >= 0.5:
|
||||||
fn += 1
|
fn += 1
|
||||||
precis = tp / (tp + fp)
|
precis = tp / (tp + fp)
|
||||||
recall = tp / (tp + fn)
|
recall = tp / (tp + fn)
|
||||||
|
@ -80,7 +81,7 @@ def load_data(limit=0):
|
||||||
train_data = train_data[-limit:]
|
train_data = train_data[-limit:]
|
||||||
|
|
||||||
texts, labels = zip(*train_data)
|
texts, labels = zip(*train_data)
|
||||||
cats = [(['POSITIVE'] if y else []) for y in labels]
|
cats = [{'POSITIVE': bool(y)} for y in labels]
|
||||||
|
|
||||||
split = int(len(train_data) * 0.8)
|
split = int(len(train_data) * 0.8)
|
||||||
|
|
||||||
|
@ -97,7 +98,7 @@ def main(model_loc=None):
|
||||||
textcat = TextCategorizer(tokenizer.vocab, labels=['POSITIVE'])
|
textcat = TextCategorizer(tokenizer.vocab, labels=['POSITIVE'])
|
||||||
|
|
||||||
print("Load IMDB data")
|
print("Load IMDB data")
|
||||||
(train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=1000)
|
(train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=2000)
|
||||||
|
|
||||||
print("Itn.\tLoss\tP\tR\tF")
|
print("Itn.\tLoss\tP\tR\tF")
|
||||||
progress = '{i:d} {loss:.3f} {textcat_p:.3f} {textcat_r:.3f} {textcat_f:.3f}'
|
progress = '{i:d} {loss:.3f} {textcat_p:.3f} {textcat_r:.3f} {textcat_f:.3f}'
|
||||||
|
|
|
@ -387,7 +387,7 @@ cdef class GoldParse:
|
||||||
|
|
||||||
def __init__(self, doc, annot_tuples=None, words=None, tags=None, heads=None,
|
def __init__(self, doc, annot_tuples=None, words=None, tags=None, heads=None,
|
||||||
deps=None, entities=None, make_projective=False,
|
deps=None, entities=None, make_projective=False,
|
||||||
cats=tuple()):
|
cats=None):
|
||||||
"""Create a GoldParse.
|
"""Create a GoldParse.
|
||||||
|
|
||||||
doc (Doc): The document the annotations refer to.
|
doc (Doc): The document the annotations refer to.
|
||||||
|
@ -398,12 +398,15 @@ cdef class GoldParse:
|
||||||
entities (iterable): A sequence of named entity annotations, either as
|
entities (iterable): A sequence of named entity annotations, either as
|
||||||
BILUO tag strings, or as `(start_char, end_char, label)` tuples,
|
BILUO tag strings, or as `(start_char, end_char, label)` tuples,
|
||||||
representing the entity positions.
|
representing the entity positions.
|
||||||
cats (iterable): A sequence of labels for text classification. Each
|
cats (dict): Labels for text classification. Each key in the dictionary
|
||||||
label may be a string or an int, or a `(start_char, end_char, label)`
|
may be a string or an int, or a `(start_char, end_char, label)`
|
||||||
tuple, indicating that the label is applied to only part of the
|
tuple, indicating that the label is applied to only part of the
|
||||||
document (usually a sentence). Unlike entity annotations, label
|
document (usually a sentence). Unlike entity annotations, label
|
||||||
annotations can overlap, i.e. a single word can be covered by
|
annotations can overlap, i.e. a single word can be covered by
|
||||||
multiple labelled spans.
|
multiple labelled spans. The TextCategorizer component expects
|
||||||
|
true examples of a label to have the value 1.0, and negative examples
|
||||||
|
of a label to have the value 0.0. Labels not in the dictionary are
|
||||||
|
treated as missing -- the gradient for those labels will be zero.
|
||||||
RETURNS (GoldParse): The newly constructed object.
|
RETURNS (GoldParse): The newly constructed object.
|
||||||
"""
|
"""
|
||||||
if words is None:
|
if words is None:
|
||||||
|
@ -434,7 +437,7 @@ cdef class GoldParse:
|
||||||
self.c.sent_start = <int*>self.mem.alloc(len(doc), sizeof(int))
|
self.c.sent_start = <int*>self.mem.alloc(len(doc), sizeof(int))
|
||||||
self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
|
self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
|
||||||
|
|
||||||
self.cats = list(cats)
|
self.cats = {} if cats is None else dict(cats)
|
||||||
self.words = [None] * len(doc)
|
self.words = [None] * len(doc)
|
||||||
self.tags = [None] * len(doc)
|
self.tags = [None] * len(doc)
|
||||||
self.heads = [None] * len(doc)
|
self.heads = [None] * len(doc)
|
||||||
|
|
|
@ -551,7 +551,6 @@ class NeuralLabeller(NeuralTagger):
|
||||||
label = self.make_label(i, words, tags, heads, deps, ents)
|
label = self.make_label(i, words, tags, heads, deps, ents)
|
||||||
if label is not None and label not in self.labels:
|
if label is not None and label not in self.labels:
|
||||||
self.labels[label] = len(self.labels)
|
self.labels[label] = len(self.labels)
|
||||||
print(len(self.labels))
|
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
token_vector_width = util.env_opt('token_vector_width')
|
token_vector_width = util.env_opt('token_vector_width')
|
||||||
self.model = chain(
|
self.model = chain(
|
||||||
|
@ -720,11 +719,17 @@ class TextCategorizer(BaseThincComponent):
|
||||||
|
|
||||||
def get_loss(self, docs, golds, scores):
|
def get_loss(self, docs, golds, scores):
|
||||||
truths = numpy.zeros((len(golds), len(self.labels)), dtype='f')
|
truths = numpy.zeros((len(golds), len(self.labels)), dtype='f')
|
||||||
|
not_missing = numpy.ones((len(golds), len(self.labels)), dtype='f')
|
||||||
for i, gold in enumerate(golds):
|
for i, gold in enumerate(golds):
|
||||||
for j, label in enumerate(self.labels):
|
for j, label in enumerate(self.labels):
|
||||||
truths[i, j] = label in gold.cats
|
if label in gold.cats:
|
||||||
|
truths[i, j] = gold.cats[label]
|
||||||
|
else:
|
||||||
|
not_missing[i, j] = 0.
|
||||||
truths = self.model.ops.asarray(truths)
|
truths = self.model.ops.asarray(truths)
|
||||||
|
not_missing = self.model.ops.asarray(not_missing)
|
||||||
d_scores = (scores-truths) / scores.shape[0]
|
d_scores = (scores-truths) / scores.shape[0]
|
||||||
|
d_scores *= not_missing
|
||||||
mean_square_error = ((scores-truths)**2).sum(axis=1).mean()
|
mean_square_error = ((scores-truths)**2).sum(axis=1).mean()
|
||||||
return mean_square_error, d_scores
|
return mean_square_error, d_scores
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user