From 445c670a2d537598b3d562fb7f444050164a260b Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Fri, 2 Dec 2022 09:33:52 +0100 Subject: [PATCH] Fix spancat for zero suggestions (#11860) * Add test for spancat predict with zero suggestions * Fix spancat for zero suggestions * Undo changes to extract_spans * Use .sum() as in update --- spacy/pipeline/spancat.py | 5 +++- spacy/tests/pipeline/test_spancat.py | 43 ++++++++++++++++++++++------ 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index 0a84c72fd..a3388e81a 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -272,7 +272,10 @@ class SpanCategorizer(TrainablePipe): DOCS: https://spacy.io/api/spancategorizer#predict """ indices = self.suggester(docs, ops=self.model.ops) - scores = self.model.predict((docs, indices)) # type: ignore + if indices.lengths.sum() == 0: + scores = self.model.ops.alloc2f(0, 0) + else: + scores = self.model.predict((docs, indices)) # type: ignore return indices, scores def set_candidates( diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py index 15256a763..e9db983d3 100644 --- a/spacy/tests/pipeline/test_spancat.py +++ b/spacy/tests/pipeline/test_spancat.py @@ -372,24 +372,39 @@ def test_overfitting_IO_overlapping(): def test_zero_suggestions(): - # Test with a suggester that returns 0 suggestions + # Test with a suggester that can return 0 suggestions - @registry.misc("test_zero_suggester") - def make_zero_suggester(): - def zero_suggester(docs, *, ops=None): + @registry.misc("test_mixed_zero_suggester") + def make_mixed_zero_suggester(): + def mixed_zero_suggester(docs, *, ops=None): if ops is None: ops = get_current_ops() - return Ragged( - ops.xp.zeros((0, 0), dtype="i"), ops.xp.zeros((len(docs),), dtype="i") - ) + spans = [] + lengths = [] + for doc in docs: + if len(doc) > 0 and len(doc) % 2 == 0: + spans.append((0, 1)) + lengths.append(1) + else: + lengths.append(0) + spans = ops.asarray2i(spans) + lengths_array = ops.asarray1i(lengths) + if len(spans) > 0: + output = Ragged(ops.xp.vstack(spans), lengths_array) + else: + output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array) + return output - return zero_suggester + return mixed_zero_suggester fix_random_seed(0) nlp = English() spancat = nlp.add_pipe( "spancat", - config={"suggester": {"@misc": "test_zero_suggester"}, "spans_key": SPAN_KEY}, + config={ + "suggester": {"@misc": "test_mixed_zero_suggester"}, + "spans_key": SPAN_KEY, + }, ) train_examples = make_examples(nlp) optimizer = nlp.initialize(get_examples=lambda: train_examples) @@ -397,6 +412,16 @@ def test_zero_suggestions(): assert set(spancat.labels) == {"LOC", "PERSON"} nlp.update(train_examples, sgd=optimizer) + # empty doc + nlp("") + # single doc with zero suggestions + nlp("one") + # single doc with one suggestion + nlp("two two") + # batch with mixed zero/one suggestions + list(nlp.pipe(["one", "two two", "three three three", "", "four four four four"])) + # batch with no suggestions + list(nlp.pipe(["", "one", "three three three"])) def test_set_candidates():