wire up different make_spangroups for single and multilabel

This commit is contained in:
kadarakos 2023-01-31 16:27:26 +00:00
parent 52e7324df4
commit dceeb02b94
2 changed files with 19 additions and 4 deletions

View File

@ -237,6 +237,7 @@ def make_spancat_singlelabel(
allow_overlap=allow_overlap,
name=name,
scorer=scorer,
single_label=True
)
@ -463,9 +464,23 @@ class SpanCategorizer(TrainablePipe):
offset = 0
for i, doc in enumerate(docs):
indices_i = indices[i].dataXd
doc.spans[self.key] = self._make_span_group(
doc, indices_i, scores[offset : offset + indices.lengths[i]], labels # type: ignore[arg-type]
)
if self.single_label:
allow_overlap = cast(bool, self.cfg["allow_overlap"])
doc.spans[self.key] = self._make_span_group_singlelabel(
doc,
indices_i,
scores[offset : offset + indices.lengths[i]],
labels, # type: ignore[arg-type]
allow_overlap
)
else:
doc.spans[self.key] = self._make_span_group_multilabel(
doc,
indices_i,
scores[offset : offset + indices.lengths[i]],
labels, # type: ignore[arg-type]
)
offset += indices.lengths[i]
def update(

View File

@ -129,7 +129,7 @@ def test_make_spangroup(max_positive, nr_results):
scores = numpy.asarray(
[[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"
)
spangroup = spancat._make_span_group(doc, indices, scores, labels)
spangroup = spancat._make_span_group_multilabel(doc, indices, scores, labels)
assert len(spangroup) == nr_results
# first span is always the second token "London"