Inherit from SpanCat instead of TrainablePipe

This commit changes the inheritance structure of Exclusive_Spancat,
now it's inheriting from SpanCategorizer than TrainablePipe. This
allows me to remove duplicate methods that are already present in
the parent function.
This commit is contained in:
Lj Miranda 2022-11-02 12:30:41 +08:00
parent bdf2a1d1fe
commit 8548e2c311

View File

@ -13,7 +13,7 @@ from ..tokens import Doc, Span, SpanGroup
from ..training import Example, validate_examples from ..training import Example, validate_examples
from ..vocab import Vocab from ..vocab import Vocab
from .spancat import spancat_score, build_ngram_suggester from .spancat import spancat_score, build_ngram_suggester
from .trainable_pipe import TrainablePipe from .spancat import SpanCategorizer
spancat_exclusive_default_config = """ spancat_exclusive_default_config = """
@ -71,7 +71,7 @@ def make_spancat(
scorer: Optional[Callable], scorer: Optional[Callable],
negative_weight: float = 1.0, negative_weight: float = 1.0,
allow_overlap: bool = True, allow_overlap: bool = True,
) -> "SpanCategorizerExclusive": ) -> "Exclusive_SpanCategorizer":
"""Create a SpanCategorizerExclusive component. The span categorizer consists of two """Create a SpanCategorizerExclusive component. The span categorizer consists of two
parts: a suggester function that proposes candidate spans, and a labeller parts: a suggester function that proposes candidate spans, and a labeller
model that predicts a single label for each span. model that predicts a single label for each span.
@ -94,7 +94,7 @@ def make_spancat(
allow_overlap (bool): If True the data is assumed to allow_overlap (bool): If True the data is assumed to
contain overlapping spans. contain overlapping spans.
""" """
return SpanCategorizerExclusive( return Exclusive_SpanCategorizer(
nlp.vocab, nlp.vocab,
suggester=suggester, suggester=suggester,
model=model, model=model,
@ -127,8 +127,8 @@ class Ranges:
return False return False
class SpanCategorizerExclusive(TrainablePipe): class Exclusive_SpanCategorizer(SpanCategorizer):
"""Pipeline component to label spans of text. """Pipeline component to label non-overlapping spans of text.
DOCS: https://spacy.io/api/spancategorizerexclusive DOCS: https://spacy.io/api/spancategorizerexclusive
""" """
@ -176,47 +176,6 @@ class SpanCategorizerExclusive(TrainablePipe):
self.name = name self.name = name
self.scorer = scorer self.scorer = scorer
@property
def key(self) -> str:
"""Key of the doc.spans dict to save the spans under. During
initialization and training, the component will look for spans on the
reference document under the same key.
"""
return str(self.cfg["spans_key"])
def add_label(self, label: str) -> int:
"""Add a new label to the pipe.
label (str): The label to add.
RETURNS (int): 0 if label is already present, otherwise 1.
DOCS: https://spacy.io/api/spancategorizerexclusive#add_label
"""
if not isinstance(label, str):
raise ValueError(Errors.E187)
if label in self.labels:
return 0
self._allow_extra_label()
self.cfg["labels"].append(label) # type: ignore
self.vocab.strings.add(label)
return 1
@property
def labels(self) -> Tuple[str]:
"""RETURNS (Tuple[str]): The labels currently added to the component.
DOCS: https://spacy.io/api/spancategorizerexclusive#labels
"""
return tuple(self.cfg["labels"]) # type: ignore
@property
def label_data(self) -> List[str]:
"""RETURNS (List[str]): Information about the component's labels.
DOCS: https://spacy.io/api/spancategorizerexclusive#label_data
"""
return list(self.labels)
@property @property
def label_map(self) -> Dict[str, int]: def label_map(self) -> Dict[str, int]:
"""RETURNS (Dict[str, int]): The label map.""" """RETURNS (Dict[str, int]): The label map."""
@ -232,37 +191,6 @@ class SpanCategorizerExclusive(TrainablePipe):
"""RETURNS (int): Number of labels including the negative label.""" """RETURNS (int): Number of labels including the negative label."""
return len(self.label_data) + 1 return len(self.label_data) + 1
def predict(self, docs: Iterable[Doc]):
"""Apply the pipeline's model to a batch of docs, without modifying them.
docs (Iterable[Doc]): The documents to predict.
RETURNS: The models prediction for each document.
DOCS: https://spacy.io/api/spancategorizerexclusive#predict
"""
indices = self.suggester(docs, ops=self.model.ops)
scores = self.model.predict((docs, indices)) # type: ignore
return indices, scores
def set_candidates(
self, docs: Iterable[Doc], *, candidates_key: str = "candidates"
) -> None:
"""Use the spancat suggester to add a list of span candidates to a
list of docs. Intended to be used for debugging purposes.
docs (Iterable[Doc]): The documents to modify.
candidates_key (str): Key of the Doc.spans dict to save the
candidate spans under.
DOCS: https://spacy.io/api/spancategorizerexclusive#set_candidates
"""
suggester_output = self.suggester(docs, ops=self.model.ops)
for candidates, doc in zip(suggester_output, docs): # type: ignore
doc.spans[candidates_key] = []
for index in candidates.dataXd:
doc.spans[candidates_key].append(doc[index[0] : index[1]])
def set_annotations(self, docs: Iterable[Doc], indices_scores) -> None: def set_annotations(self, docs: Iterable[Doc], indices_scores) -> None:
"""Modify a batch of Doc objects, using pre-computed scores. """Modify a batch of Doc objects, using pre-computed scores.
@ -286,47 +214,6 @@ class SpanCategorizerExclusive(TrainablePipe):
) )
offset += indices.lengths[i] offset += indices.lengths[i]
def update(
self,
examples: Iterable[Example],
*,
drop: float = 0.0,
sgd: Optional[Optimizer] = None,
losses: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
"""Learn from a batch of documents and gold-standard information,
updating the pipe's model. Delegates to predict and get_loss.
examples (Iterable[Example]): A batch of Example objects.
drop (float): The dropout rate.
sgd (thinc.api.Optimizer): The optimizer.
losses (Dict[str, float]): Optional record of the loss during training.
Updated using the component name as the key.
RETURNS (Dict[str, float]): The updated losses dictionary.
DOCS: https://spacy.io/api/spancategorizerexclusive#update
"""
if losses is None:
losses = {}
losses.setdefault(self.name, 0.0)
validate_examples(examples, "SpanCategorizer.update")
self._validate_categories(examples)
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
# Handle cases where there are no tokens in any docs.
return losses
docs = [eg.predicted for eg in examples]
spans = self.suggester(docs, ops=self.model.ops)
if spans.lengths.sum() == 0:
return losses
set_dropout_rate(self.model, drop)
scores, backprop_scores = self.model.begin_update((docs, spans))
loss, d_scores = self.get_loss(examples, (spans, scores))
backprop_scores(d_scores) # type: ignore
if sgd is not None:
self.finish_update(sgd)
losses[self.name] += loss
return losses
def get_loss( def get_loss(
self, examples: Iterable[Example], spans_scores: Tuple[Ragged, Floats2d] self, examples: Iterable[Example], spans_scores: Tuple[Ragged, Floats2d]
) -> Tuple[float, float]: ) -> Tuple[float, float]: