more docstring and fix negative_label

This commit is contained in:
kadarakos 2023-02-01 11:16:34 +00:00
parent edf9134e45
commit 5ccb154972

View File

@ -1,4 +1,4 @@
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast, Union
from dataclasses import dataclass from dataclasses import dataclass
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
from thinc.api import Optimizer from thinc.api import Optimizer
@ -309,7 +309,12 @@ class SpanCategorizer(TrainablePipe):
threshold: Optional[float] = None, threshold: Optional[float] = None,
scorer: Optional[Callable] = spancat_score, scorer: Optional[Callable] = spancat_score,
) -> None: ) -> None:
"""Initialize the span categorizer. """Initialize the multilabel or multiclass span categorizer.
The 'single_label' argument configures whether the component
should only produce one label per span (multiclass) or if it
can produce multiple labels per span (multilabel). In the
multilabel case the classification layer is expected to be
Logistic; in the multiclass case it is expected to be Softmax.
vocab (Vocab): The shared vocabulary. vocab (Vocab): The shared vocabulary.
model (thinc.api.Model): The Thinc Model powering the pipeline component. model (thinc.api.Model): The Thinc Model powering the pipeline component.
suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans. suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
@ -321,16 +326,23 @@ class SpanCategorizer(TrainablePipe):
During initialization and training, the component will look for During initialization and training, the component will look for
spans on the reference document under the same key. Defaults to spans on the reference document under the same key. Defaults to
`"spans"`. `"spans"`.
single_label (bool): Whether to configure SpanCategorizer to produce
a single label per span. In this case it is expected that the scorer
layer is Softmax. Otherwise it is expected to be Logistic. When single_label
is true the SpanCategorizer internally has a negative-label indicating
that a span should not receive any of the labels found in the corpus.
threshold (Optional[float]): Minimum probability to consider a prediction threshold (Optional[float]): Minimum probability to consider a prediction
positive. Spans with a positive prediction will be saved on the Doc. positive in the multilabel use case. Defaults to 0.5 when single_label is
Defaults to 0.5. False; otherwise it is None. Spans with a positive prediction will be saved on the Doc.
max_positive (Optional[int]): Maximum number of labels to consider max_positive (Optional[int]): Maximum number of labels to consider
positive per span. Defaults to None, indicating no limit. positive per span. Defaults to None, indicating no limit. This is
unused when single_label is True.
negative_weight (float): Multiplier for the loss terms. negative_weight (float): Multiplier for the loss terms.
Can be used to downweight the negative samples if there are too many. Can be used to downweight the negative samples if there are too many
when single_label is True. Otherwise it is unused.
allow_overlap (bool): If True the data is assumed to contain overlapping spans. allow_overlap (bool): If True the data is assumed to contain overlapping spans.
Otherwise it produces non-overlapping spans greedily prioritizing Otherwise it produces non-overlapping spans greedily prioritizing
higher assigned label scores. higher assigned label scores. Only used when single_label is True.
scorer (Optional[Callable]): The scoring method. Defaults to scorer (Optional[Callable]): The scoring method. Defaults to
Scorer.score_spans for the Doc.spans[spans_key] with overlapping Scorer.score_spans for the Doc.spans[spans_key] with overlapping
spans allowed. spans allowed.
@ -422,12 +434,12 @@ class SpanCategorizer(TrainablePipe):
return len(self.labels) return len(self.labels)
@property @property
def _negative_label(self) -> int: def _negative_label(self) -> Union[int, None]:
"""RETURNS (int): Index of the negative label.""" """RETURNS (Union[int, None]): Index of the negative label."""
if self.single_label: if self.single_label:
return -1
else:
return len(self.label_data) return len(self.label_data)
else:
return None
def predict(self, docs: Iterable[Doc]): def predict(self, docs: Iterable[Doc]):
"""Apply the pipeline's model to a batch of docs, without modifying them. """Apply the pipeline's model to a batch of docs, without modifying them.