This commit is contained in:
kadarakos 2023-01-31 16:30:12 +00:00
parent dceeb02b94
commit 8a807ef1dd

View File

@ -71,7 +71,9 @@ depth = 4
""" """
DEFAULT_SPANCAT_MODEL = Config().from_str(spancat_default_config)["model"] DEFAULT_SPANCAT_MODEL = Config().from_str(spancat_default_config)["model"]
DEFAULT_SPANCAT_SINGLELABEL_MODEL = Config().from_str(spancat_singlelabel_default_config)["model"] DEFAULT_SPANCAT_SINGLELABEL_MODEL = Config().from_str(
spancat_singlelabel_default_config
)["model"]
@runtime_checkable @runtime_checkable
@ -191,7 +193,7 @@ def make_spancat(
"negative_weight": 1.0, "negative_weight": 1.0,
"suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]}, "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
"scorer": {"@scorers": "spacy.spancat_scorer.v1"}, "scorer": {"@scorers": "spacy.spancat_scorer.v1"},
"allow_overlap": True "allow_overlap": True,
}, },
default_score_weights={"spans_sc_f": 1.0, "spans_sc_p": 0.0, "spans_sc_r": 0.0}, default_score_weights={"spans_sc_f": 1.0, "spans_sc_p": 0.0, "spans_sc_r": 0.0},
) )
@ -237,7 +239,7 @@ def make_spancat_singlelabel(
allow_overlap=allow_overlap, allow_overlap=allow_overlap,
name=name, name=name,
scorer=scorer, scorer=scorer,
single_label=True single_label=True,
) )
@ -332,7 +334,7 @@ class SpanCategorizer(TrainablePipe):
"threshold": threshold, "threshold": threshold,
"max_positive": max_positive, "max_positive": max_positive,
"negative_weight": negative_weight, "negative_weight": negative_weight,
"allow_overlap": allow_overlap "allow_overlap": allow_overlap,
} }
self.vocab = vocab self.vocab = vocab
self.suggester = suggester self.suggester = suggester
@ -471,7 +473,7 @@ class SpanCategorizer(TrainablePipe):
indices_i, indices_i,
scores[offset : offset + indices.lengths[i]], scores[offset : offset + indices.lengths[i]],
labels, # type: ignore[arg-type] labels, # type: ignore[arg-type]
allow_overlap allow_overlap,
) )
else: else:
doc.spans[self.key] = self._make_span_group_multilabel( doc.spans[self.key] = self._make_span_group_multilabel(
@ -638,7 +640,7 @@ class SpanCategorizer(TrainablePipe):
indices: Ints2d, indices: Ints2d,
scores: Floats2d, scores: Floats2d,
labels: List[str], labels: List[str],
allow_overlap: bool = True allow_overlap: bool = True,
) -> SpanGroup: ) -> SpanGroup:
spans = SpanGroup(doc, name=self.key) spans = SpanGroup(doc, name=self.key)
max_positive = self.cfg["max_positive"] max_positive = self.cfg["max_positive"]