From 7de3b129ab6978259ce7c6ee50495ca5016c78f9 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Tue, 30 Jul 2019 14:58:01 +0200 Subject: [PATCH] Resolve edge case when calling textcat.predict with empty doc (#4035) * resolve edge case where no doc has tokens when calling textcat.predict * more explicit value test --- spacy/pipeline/pipes.pyx | 9 +++- spacy/tests/regression/test_issue4030.py | 57 ++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 spacy/tests/regression/test_issue4030.py diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index 3b5e3d41c..ba1fca24e 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -942,9 +942,16 @@ class TextCategorizer(Pipe): def predict(self, docs): self.require_model() + tensors = [doc.tensor for doc in docs] + + if not any(len(doc) for doc in docs): + # Handle cases where there are no tokens in any docs. + xp = get_array_module(tensors) + scores = xp.zeros((len(docs), len(self.labels))) + return scores, tensors + scores = self.model(docs) scores = self.model.ops.asarray(scores) - tensors = [doc.tensor for doc in docs] return scores, tensors def set_annotations(self, docs, scores, tensors=None): diff --git a/spacy/tests/regression/test_issue4030.py b/spacy/tests/regression/test_issue4030.py new file mode 100644 index 000000000..c331fa1d2 --- /dev/null +++ b/spacy/tests/regression/test_issue4030.py @@ -0,0 +1,57 @@ +# coding: utf8 +from __future__ import unicode_literals + +import spacy +from spacy.util import minibatch, compounding + + +def test_issue4030(): + """ Test whether textcat works fine with empty doc """ + unique_classes = ["offensive", "inoffensive"] + x_train = [ + "This is an offensive text", + "This is the second offensive text", + "inoff", + ] + y_train = ["offensive", "offensive", "inoffensive"] + + # preparing the data + pos_cats = list() + for train_instance in y_train: + pos_cats.append({label: label == train_instance for label in unique_classes}) + train_data = list(zip(x_train, [{"cats": cats} for cats in pos_cats])) + + # set up the spacy model with a text categorizer component + nlp = spacy.blank("en") + + textcat = nlp.create_pipe( + "textcat", + config={"exclusive_classes": True, "architecture": "bow", "ngram_size": 2}, + ) + + for label in unique_classes: + textcat.add_label(label) + nlp.add_pipe(textcat, last=True) + + # training the network + other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "textcat"] + with nlp.disable_pipes(*other_pipes): + optimizer = nlp.begin_training() + for i in range(3): + losses = {} + batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001)) + + for batch in batches: + texts, annotations = zip(*batch) + nlp.update( + docs=texts, + golds=annotations, + sgd=optimizer, + drop=0.1, + losses=losses, + ) + + # processing of an empty doc should result in 0.0 for all categories + doc = nlp("") + assert doc.cats["offensive"] == 0.0 + assert doc.cats["inoffensive"] == 0.0