fix types of fwd functions

This commit is contained in:
svlandeg 2021-05-27 16:36:46 +02:00
parent 04b55bf054
commit 391b512afd

View File

@ -133,7 +133,7 @@ def build_span_embedder(
def span_embeddings_forward( def span_embeddings_forward(
model, inputs: Tuple[List[Floats2d], List[Doc]], is_train model, inputs: Tuple[List[Floats2d], List[Doc]], is_train
) -> SpanEmbeddings: ) -> Tuple[SpanEmbeddings, Callable]:
ops = model.ops ops = model.ops
xp = ops.xp xp = ops.xp
@ -223,7 +223,7 @@ def build_coarse_pruner(
def coarse_prune( def coarse_prune(
model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train
) -> SpanEmbeddings: ) -> Tuple[Tuple[Floats1d, SpanEmbeddings], Callable]:
"""Given scores for mention, output the top non-crossing mentions. """Given scores for mention, output the top non-crossing mentions.
Mentions can contain other mentions, but candidate mentions cannot cross each other. Mentions can contain other mentions, but candidate mentions cannot cross each other.
@ -320,7 +320,7 @@ def build_ant_scorer(
def ant_scorer_forward( def ant_scorer_forward(
model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train
) -> Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d]: ) -> Tuple[Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d], Callable]:
ops = model.ops ops = model.ops
xp = ops.xp xp = ops.xp