Update span predictor docstrings

This commit is contained in:
Paul O'Leary McCann 2022-07-06 17:28:05 +09:00
parent c4de3e51a2
commit 5e405738d2

View File

@ -95,6 +95,8 @@ class SpanPredictor(TrainablePipe):
"""Pipeline component to resolve one-token spans to full spans. """Pipeline component to resolve one-token spans to full spans.
Used in coreference resolution. Used in coreference resolution.
DOCS: https://spacy.io/api/span_predictor
""" """
def __init__( def __init__(
@ -119,6 +121,14 @@ class SpanPredictor(TrainablePipe):
} }
def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]: def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
"""Apply the pipeline's model to a batch of docs, without modifying them.
Return the list of predicted span clusters.
docs (Iterable[Doc]): The documents to predict.
RETURNS (List[MentionClusters]): The model's prediction for each document.
DOCS: https://spacy.io/api/span_predictor#predict
"""
# for now pretend there's just one doc # for now pretend there's just one doc
out = [] out = []
@ -151,6 +161,13 @@ class SpanPredictor(TrainablePipe):
return out return out
def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None: def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
"""Modify a batch of Doc objects, using pre-computed span clusters.
docs (Iterable[Doc]): The documents to modify.
clusters_by_doc: The span clusters for each doc, produced by SpanPredictor.predict.
DOCS: https://spacy.io/api/span_predictor#set_annotations
"""
for doc, clusters in zip(docs, clusters_by_doc): for doc, clusters in zip(docs, clusters_by_doc):
for ii, cluster in enumerate(clusters): for ii, cluster in enumerate(clusters):
spans = [doc[mm[0] : mm[1]] for mm in cluster] spans = [doc[mm[0] : mm[1]] for mm in cluster]
@ -166,6 +183,15 @@ class SpanPredictor(TrainablePipe):
) -> Dict[str, float]: ) -> Dict[str, float]:
"""Learn from a batch of documents and gold-standard information, """Learn from a batch of documents and gold-standard information,
updating the pipe's model. Delegates to predict and get_loss. updating the pipe's model. Delegates to predict and get_loss.
examples (Iterable[Example]): A batch of Example objects.
drop (float): The dropout rate.
sgd (thinc.api.Optimizer): The optimizer.
losses (Dict[str, float]): Optional record of the loss during training.
Updated using the component name as the key.
RETURNS (Dict[str, float]): The updated losses dictionary.
DOCS: https://spacy.io/api/span_predictor#update
""" """
if losses is None: if losses is None:
losses = {} losses = {}
@ -222,6 +248,15 @@ class SpanPredictor(TrainablePipe):
examples: Iterable[Example], examples: Iterable[Example],
span_scores: Floats3d, span_scores: Floats3d,
): ):
"""Find the loss and gradient of loss for the batch of documents and
their predicted scores.
examples (Iterable[Example]): The batch of examples.
span_scores (Floats3d): Scores representing the model's predictions.
RETURNS (Tuple[float, float]): The loss and the gradient.
DOCS: https://spacy.io/api/span_predictor#get_loss
"""
ops = self.model.ops ops = self.model.ops
# NOTE This is doing fake batching, and should always get a list of one example # NOTE This is doing fake batching, and should always get a list of one example
@ -258,6 +293,15 @@ class SpanPredictor(TrainablePipe):
*, *,
nlp: Optional[Language] = None, nlp: Optional[Language] = None,
) -> None: ) -> None:
"""Initialize the pipe for training, using a representative set
of data examples.
get_examples (Callable[[], Iterable[Example]]): Function that
returns a representative sample of gold-standard Example objects.
nlp (Language): The current nlp object the component is part of.
DOCS: https://spacy.io/api/span_predictor#initialize
"""
validate_get_examples(get_examples, "SpanPredictor.initialize") validate_get_examples(get_examples, "SpanPredictor.initialize")
X = [] X = []