diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index 8d1be06c3..e2f53be0d 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -408,16 +408,21 @@ class SpanCategorizer(TrainablePipe): spans = SpanGroup(doc, name=self.key) max_positive = self.cfg["max_positive"] threshold = self.cfg["threshold"] + + keeps = scores >= threshold + ranked = (scores * -1).argsort() + keeps[ranked[:, max_positive:]] = False + spans.attrs["scores"] = scores[keeps].flatten() + + indices = self.model.ops.to_numpy(indices) + keeps = self.model.ops.to_numpy(keeps) + for i in range(indices.shape[0]): - start = int(indices[i, 0]) - end = int(indices[i, 1]) - positives = [] - for j, score in enumerate(scores[i]): - if score >= threshold: - positives.append((score, start, end, labels[j])) - positives.sort(reverse=True) - if max_positive: - positives = positives[:max_positive] - for score, start, end, label in positives: - spans.append(Span(doc, start, end, label=label)) + start = indices[i, 0] + end = indices[i, 1] + + for j, keep in enumerate(keeps[i]): + if keep: + spans.append(Span(doc, start, end, label=labels[j])) + return spans diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py index 6a5ae2c66..bf1c00041 100644 --- a/spacy/tests/pipeline/test_spancat.py +++ b/spacy/tests/pipeline/test_spancat.py @@ -85,12 +85,14 @@ def test_simple_train(): doc = nlp("I like London and Berlin.") assert doc.spans[spancat.key] == doc.spans[SPAN_KEY] assert len(doc.spans[spancat.key]) == 2 + assert len(doc.spans[spancat.key].attrs["scores"]) == 2 assert doc.spans[spancat.key][0].text == "London" scores = nlp.evaluate(get_examples()) assert f"spans_{SPAN_KEY}_f" in scores assert scores[f"spans_{SPAN_KEY}_f"] == 1.0 + def test_ngram_suggester(en_tokenizer): # test different n-gram lengths for size in [1, 2, 3]: diff --git a/website/docs/api/spancategorizer.md b/website/docs/api/spancategorizer.md index 57395846d..d37b2f23a 100644 --- a/website/docs/api/spancategorizer.md +++ b/website/docs/api/spancategorizer.md @@ -13,6 +13,22 @@ A span categorizer consists of two parts: a [suggester function](#suggesters) that proposes candidate spans, which may or may not overlap, and a labeler model that predicts zero or more labels for each candidate. +Predicted spans will be saved in a [`SpanGroup`](/api/spangroup) on the doc. +Individual span scores can be found in `spangroup.attrs["scores"]`. + +## Assigned Attributes {#assigned-attributes} + +Predictions will be saved to `Doc.spans[spans_key]` as a +[`SpanGroup`](/api/spangroup). The scores for the spans in the `SpanGroup` will +be saved in `SpanGroup.attrs["scores"]`. + +`spans_key` defaults to `"sc"`, but can be passed as a parameter. + +| Location | Value | +| -------------------------------------- | -------------------------------------------------------- | +| `Doc.spans[spans_key]` | The annotated spans. ~~SpanGroup~~ | +| `Doc.spans[spans_key].attrs["scores"]` | The score for each span in the `SpanGroup`. ~~Floats1d~~ | + ## Config and implementation {#config} The default config is defined by the pipeline component factory and describes