mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-25 11:23:40 +03:00
Add scores to output in spancat (#8855)
* Add scores to output in spancat This exposes the scores as an attribute on the SpanGroup. Includes a basic test. * Add basic doc note * Vectorize score calcs * Add "annotation format" section * Update website/docs/api/spancategorizer.md Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Clean up doc section * Ran prettier on docs * Get arrays off the gpu before iterating over them * Remove int() calls Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
a1e9f19460
commit
6029cfc391
|
@ -408,16 +408,21 @@ class SpanCategorizer(TrainablePipe):
|
||||||
spans = SpanGroup(doc, name=self.key)
|
spans = SpanGroup(doc, name=self.key)
|
||||||
max_positive = self.cfg["max_positive"]
|
max_positive = self.cfg["max_positive"]
|
||||||
threshold = self.cfg["threshold"]
|
threshold = self.cfg["threshold"]
|
||||||
|
|
||||||
|
keeps = scores >= threshold
|
||||||
|
ranked = (scores * -1).argsort()
|
||||||
|
keeps[ranked[:, max_positive:]] = False
|
||||||
|
spans.attrs["scores"] = scores[keeps].flatten()
|
||||||
|
|
||||||
|
indices = self.model.ops.to_numpy(indices)
|
||||||
|
keeps = self.model.ops.to_numpy(keeps)
|
||||||
|
|
||||||
for i in range(indices.shape[0]):
|
for i in range(indices.shape[0]):
|
||||||
start = int(indices[i, 0])
|
start = indices[i, 0]
|
||||||
end = int(indices[i, 1])
|
end = indices[i, 1]
|
||||||
positives = []
|
|
||||||
for j, score in enumerate(scores[i]):
|
for j, keep in enumerate(keeps[i]):
|
||||||
if score >= threshold:
|
if keep:
|
||||||
positives.append((score, start, end, labels[j]))
|
spans.append(Span(doc, start, end, label=labels[j]))
|
||||||
positives.sort(reverse=True)
|
|
||||||
if max_positive:
|
|
||||||
positives = positives[:max_positive]
|
|
||||||
for score, start, end, label in positives:
|
|
||||||
spans.append(Span(doc, start, end, label=label))
|
|
||||||
return spans
|
return spans
|
||||||
|
|
|
@ -85,12 +85,14 @@ def test_simple_train():
|
||||||
doc = nlp("I like London and Berlin.")
|
doc = nlp("I like London and Berlin.")
|
||||||
assert doc.spans[spancat.key] == doc.spans[SPAN_KEY]
|
assert doc.spans[spancat.key] == doc.spans[SPAN_KEY]
|
||||||
assert len(doc.spans[spancat.key]) == 2
|
assert len(doc.spans[spancat.key]) == 2
|
||||||
|
assert len(doc.spans[spancat.key].attrs["scores"]) == 2
|
||||||
assert doc.spans[spancat.key][0].text == "London"
|
assert doc.spans[spancat.key][0].text == "London"
|
||||||
scores = nlp.evaluate(get_examples())
|
scores = nlp.evaluate(get_examples())
|
||||||
assert f"spans_{SPAN_KEY}_f" in scores
|
assert f"spans_{SPAN_KEY}_f" in scores
|
||||||
assert scores[f"spans_{SPAN_KEY}_f"] == 1.0
|
assert scores[f"spans_{SPAN_KEY}_f"] == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_ngram_suggester(en_tokenizer):
|
def test_ngram_suggester(en_tokenizer):
|
||||||
# test different n-gram lengths
|
# test different n-gram lengths
|
||||||
for size in [1, 2, 3]:
|
for size in [1, 2, 3]:
|
||||||
|
|
|
@ -13,6 +13,22 @@ A span categorizer consists of two parts: a [suggester function](#suggesters)
|
||||||
that proposes candidate spans, which may or may not overlap, and a labeler model
|
that proposes candidate spans, which may or may not overlap, and a labeler model
|
||||||
that predicts zero or more labels for each candidate.
|
that predicts zero or more labels for each candidate.
|
||||||
|
|
||||||
|
Predicted spans will be saved in a [`SpanGroup`](/api/spangroup) on the doc.
|
||||||
|
Individual span scores can be found in `spangroup.attrs["scores"]`.
|
||||||
|
|
||||||
|
## Assigned Attributes {#assigned-attributes}
|
||||||
|
|
||||||
|
Predictions will be saved to `Doc.spans[spans_key]` as a
|
||||||
|
[`SpanGroup`](/api/spangroup). The scores for the spans in the `SpanGroup` will
|
||||||
|
be saved in `SpanGroup.attrs["scores"]`.
|
||||||
|
|
||||||
|
`spans_key` defaults to `"sc"`, but can be passed as a parameter.
|
||||||
|
|
||||||
|
| Location | Value |
|
||||||
|
| -------------------------------------- | -------------------------------------------------------- |
|
||||||
|
| `Doc.spans[spans_key]` | The annotated spans. ~~SpanGroup~~ |
|
||||||
|
| `Doc.spans[spans_key].attrs["scores"]` | The score for each span in the `SpanGroup`. ~~Floats1d~~ |
|
||||||
|
|
||||||
## Config and implementation {#config}
|
## Config and implementation {#config}
|
||||||
|
|
||||||
The default config is defined by the pipeline component factory and describes
|
The default config is defined by the pipeline component factory and describes
|
||||||
|
|
Loading…
Reference in New Issue
Block a user