diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index c9b0a1b0f..16376bba2 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -133,7 +133,7 @@ def build_span_embedder( def span_embeddings_forward( model, inputs: Tuple[List[Floats2d], List[Doc]], is_train -) -> SpanEmbeddings: +) -> Tuple[SpanEmbeddings, Callable]: ops = model.ops xp = ops.xp @@ -223,7 +223,7 @@ def build_coarse_pruner( def coarse_prune( model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train -) -> SpanEmbeddings: +) -> Tuple[Tuple[Floats1d, SpanEmbeddings], Callable]: """Given scores for mention, output the top non-crossing mentions. Mentions can contain other mentions, but candidate mentions cannot cross each other. @@ -320,7 +320,7 @@ def build_ant_scorer( def ant_scorer_forward( model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train -) -> Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d]: +) -> Tuple[Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d], Callable]: ops = model.ops xp = ops.xp