mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 16:22:29 +03:00
Wire up separate `_make_span_group_singlelabel` and `_make_span_group_multilabel` methods for single-label and multilabel span categorization
This commit is contained in:
parent
52e7324df4
commit
dceeb02b94
|
@ -237,6 +237,7 @@ def make_spancat_singlelabel(
|
||||||
allow_overlap=allow_overlap,
|
allow_overlap=allow_overlap,
|
||||||
name=name,
|
name=name,
|
||||||
scorer=scorer,
|
scorer=scorer,
|
||||||
|
single_label=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -463,9 +464,23 @@ class SpanCategorizer(TrainablePipe):
|
||||||
offset = 0
|
offset = 0
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
indices_i = indices[i].dataXd
|
indices_i = indices[i].dataXd
|
||||||
doc.spans[self.key] = self._make_span_group(
|
if self.single_label:
|
||||||
doc, indices_i, scores[offset : offset + indices.lengths[i]], labels # type: ignore[arg-type]
|
allow_overlap = cast(bool, self.cfg["allow_overlap"])
|
||||||
|
doc.spans[self.key] = self._make_span_group_singlelabel(
|
||||||
|
doc,
|
||||||
|
indices_i,
|
||||||
|
scores[offset : offset + indices.lengths[i]],
|
||||||
|
labels, # type: ignore[arg-type]
|
||||||
|
allow_overlap
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
doc.spans[self.key] = self._make_span_group_multilabel(
|
||||||
|
doc,
|
||||||
|
indices_i,
|
||||||
|
scores[offset : offset + indices.lengths[i]],
|
||||||
|
labels, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
offset += indices.lengths[i]
|
offset += indices.lengths[i]
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
|
|
|
@ -129,7 +129,7 @@ def test_make_spangroup(max_positive, nr_results):
|
||||||
scores = numpy.asarray(
|
scores = numpy.asarray(
|
||||||
[[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"
|
[[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"
|
||||||
)
|
)
|
||||||
spangroup = spancat._make_span_group(doc, indices, scores, labels)
|
spangroup = spancat._make_span_group_multilabel(doc, indices, scores, labels)
|
||||||
assert len(spangroup) == nr_results
|
assert len(spangroup) == nr_results
|
||||||
|
|
||||||
# first span is always the second token "London"
|
# first span is always the second token "London"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user