fix types of fwd functions

This commit is contained in:
svlandeg 2021-05-27 16:36:46 +02:00
parent 04b55bf054
commit 391b512afd

View File

@ -133,7 +133,7 @@ def build_span_embedder(
def span_embeddings_forward(
model, inputs: Tuple[List[Floats2d], List[Doc]], is_train
) -> SpanEmbeddings:
) -> Tuple[SpanEmbeddings, Callable]:
ops = model.ops
xp = ops.xp
@ -223,7 +223,7 @@ def build_coarse_pruner(
def coarse_prune(
model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train
) -> SpanEmbeddings:
) -> Tuple[Tuple[Floats1d, SpanEmbeddings], Callable]:
"""Given scores for mention, output the top non-crossing mentions.
Mentions can contain other mentions, but candidate mentions cannot cross each other.
@ -320,7 +320,7 @@ def build_ant_scorer(
def ant_scorer_forward(
model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train
) -> Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d]:
) -> Tuple[Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d], Callable]:
ops = model.ops
xp = ops.xp