Intervals to private and document 'name' param

This commit is contained in:
kadarakos 2023-03-03 15:51:57 +00:00
parent 6d67ab7670
commit 854d1614a9

View File

@ -155,6 +155,8 @@ def make_spancat(
parts: a suggester function that proposes candidate spans, and a labeller parts: a suggester function that proposes candidate spans, and a labeller
model that predicts one or more labels for each span. model that predicts one or more labels for each span.
name (str): The component instance name, used to add entries to the
losses during training.
suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans. suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
Spans are returned as a ragged array with two integer columns, for the Spans are returned as a ragged array with two integer columns, for the
start and end positions. start and end positions.
@ -218,6 +220,8 @@ def make_spancat_singlelabel(
parts: a suggester function that proposes candidate spans, and a labeller parts: a suggester function that proposes candidate spans, and a labeller
model that predicts one or more labels for each span. model that predicts one or more labels for each span.
name (str): The component instance name, used to add entries to the
losses during training.
suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans. suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
Spans are returned as a ragged array with two integer columns, for the Spans are returned as a ragged array with two integer columns, for the
start and end positions. start and end positions.
@ -271,7 +275,7 @@ def make_spancat_scorer():
@dataclass @dataclass
class Intervals: class _Intervals:
""" """
Helper class to avoid storing overlapping spans. Helper class to avoid storing overlapping spans.
""" """
@ -729,7 +733,7 @@ class SpanCategorizer(TrainablePipe):
predicted = predicted[sort_idx] predicted = predicted[sort_idx]
indices = indices[sort_idx] indices = indices[sort_idx]
keeps = keeps[sort_idx] keeps = keeps[sort_idx]
seen = Intervals() seen = _Intervals()
spans = SpanGroup(doc, name=self.key) spans = SpanGroup(doc, name=self.key)
attrs_scores = [] attrs_scores = []
for i in range(indices.shape[0]): for i in range(indices.shape[0]):