From 391b512afd6a77f7f89c2c931a6e5bb73922636e Mon Sep 17 00:00:00 2001 From: svlandeg Date: Thu, 27 May 2021 16:36:46 +0200 Subject: [PATCH] fix types of fwd functions --- spacy/ml/models/coref.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index c9b0a1b0f..16376bba2 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -133,7 +133,7 @@ def build_span_embedder( def span_embeddings_forward( model, inputs: Tuple[List[Floats2d], List[Doc]], is_train -) -> SpanEmbeddings: +) -> Tuple[SpanEmbeddings, Callable]: ops = model.ops xp = ops.xp @@ -223,7 +223,7 @@ def build_coarse_pruner( def coarse_prune( model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train -) -> SpanEmbeddings: +) -> Tuple[Tuple[Floats1d, SpanEmbeddings], Callable]: """Given scores for mention, output the top non-crossing mentions. Mentions can contain other mentions, but candidate mentions cannot cross each other. @@ -320,7 +320,7 @@ def build_ant_scorer( def ant_scorer_forward( model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train -) -> Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d]: +) -> Tuple[Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d], Callable]: ops = model.ops xp = ops.xp