From dceeb02b9446c29edfad848bb4398c2c21dbf018 Mon Sep 17 00:00:00 2001 From: kadarakos Date: Tue, 31 Jan 2023 16:27:26 +0000 Subject: [PATCH] wire up different make_spangroups for single and multilabel --- spacy/pipeline/spancat.py | 21 ++++++++++++++++++--- spacy/tests/pipeline/test_spancat.py | 2 +- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index 28a527104..ec1b252f5 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -237,6 +237,7 @@ def make_spancat_singlelabel( allow_overlap=allow_overlap, name=name, scorer=scorer, + single_label=True ) @@ -463,9 +464,23 @@ class SpanCategorizer(TrainablePipe): offset = 0 for i, doc in enumerate(docs): indices_i = indices[i].dataXd - doc.spans[self.key] = self._make_span_group( - doc, indices_i, scores[offset : offset + indices.lengths[i]], labels # type: ignore[arg-type] - ) + if self.single_label: + allow_overlap = cast(bool, self.cfg["allow_overlap"]) + doc.spans[self.key] = self._make_span_group_singlelabel( + doc, + indices_i, + scores[offset : offset + indices.lengths[i]], + labels, # type: ignore[arg-type] + allow_overlap + ) + else: + doc.spans[self.key] = self._make_span_group_multilabel( + doc, + indices_i, + scores[offset : offset + indices.lengths[i]], + labels, # type: ignore[arg-type] + ) + offset += indices.lengths[i] def update( diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py index 0d9f0fe89..56d44b7be 100644 --- a/spacy/tests/pipeline/test_spancat.py +++ b/spacy/tests/pipeline/test_spancat.py @@ -129,7 +129,7 @@ def test_make_spangroup(max_positive, nr_results): scores = numpy.asarray( [[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f" ) - spangroup = spancat._make_span_group(doc, indices, scores, labels) + spangroup = spancat._make_span_group_multilabel(doc, indices, scores, labels) assert len(spangroup) == nr_results # first span is always the second token "London"