diff --git a/spacy/pipeline/span_predictor.py b/spacy/pipeline/span_predictor.py index 99a1f7ef6..beec67473 100644 --- a/spacy/pipeline/span_predictor.py +++ b/spacy/pipeline/span_predictor.py @@ -95,6 +95,8 @@ class SpanPredictor(TrainablePipe): """Pipeline component to resolve one-token spans to full spans. Used in coreference resolution. + + DOCS: https://spacy.io/api/span_predictor """ def __init__( @@ -119,6 +121,14 @@ class SpanPredictor(TrainablePipe): } def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]: + """Apply the pipeline's model to a batch of docs, without modifying them. + Return the list of predicted span clusters. + + docs (Iterable[Doc]): The documents to predict. + RETURNS (List[MentionClusters]): The model's prediction for each document. + + DOCS: https://spacy.io/api/span_predictor#predict + """ # for now pretend there's just one doc out = [] @@ -151,6 +161,13 @@ class SpanPredictor(TrainablePipe): return out def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None: + """Modify a batch of Doc objects, using pre-computed scores. + + docs (Iterable[Doc]): The documents to modify. + clusters: The span clusters, produced by SpanPredictor.predict. + + DOCS: https://spacy.io/api/span_predictor#set_annotations + """ for doc, clusters in zip(docs, clusters_by_doc): for ii, cluster in enumerate(clusters): spans = [doc[mm[0] : mm[1]] for mm in cluster] @@ -166,6 +183,15 @@ class SpanPredictor(TrainablePipe): ) -> Dict[str, float]: """Learn from a batch of documents and gold-standard information, updating the pipe's model. Delegates to predict and get_loss. + + examples (Iterable[Example]): A batch of Example objects. + drop (float): The dropout rate. + sgd (thinc.api.Optimizer): The optimizer. + losses (Dict[str, float]): Optional record of the loss during training. + Updated using the component name as the key. + RETURNS (Dict[str, float]): The updated losses dictionary. + + DOCS: https://spacy.io/api/span_predictor#update """ if losses is None: losses = {} @@ -222,6 +248,15 @@ class SpanPredictor(TrainablePipe): examples: Iterable[Example], span_scores: Floats3d, ): + """Find the loss and gradient of loss for the batch of documents and + their predicted scores. + + examples (Iterable[Examples]): The batch of examples. + scores: Scores representing the model's predictions. + RETURNS (Tuple[float, float]): The loss and the gradient. + + DOCS: https://spacy.io/api/span_predictor#get_loss + """ ops = self.model.ops # NOTE This is doing fake batching, and should always get a list of one example @@ -258,6 +293,15 @@ class SpanPredictor(TrainablePipe): *, nlp: Optional[Language] = None, ) -> None: + """Initialize the pipe for training, using a representative set + of data examples. + + get_examples (Callable[[], Iterable[Example]]): Function that + returns a representative sample of gold-standard Example objects. + nlp (Language): The current nlp object the component is part of. + + DOCS: https://spacy.io/api/span_predictor#initialize + """ validate_get_examples(get_examples, "SpanPredictor.initialize") X = []