Add type annotations for internal models

This commit is contained in:
Paul O'Leary McCann 2022-07-11 20:03:29 +09:00
parent 4d032396b8
commit baeb35f31b
2 changed files with 10 additions and 4 deletions

View File

@ -22,11 +22,11 @@ def build_wl_coref_model(
# pairs to keep per mention after rough scoring # pairs to keep per mention after rough scoring
antecedent_limit: int = 50, antecedent_limit: int = 50,
antecedent_batch_size: int = 512, antecedent_batch_size: int = 512,
nI = None nI=None,
) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]: ) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]:
with Model.define_operators({">>": chain}): with Model.define_operators({">>": chain}):
coref_clusterer = Model( coref_clusterer: Model[List[Floats2d], Tuple[Floats2d, Ints2d]] = Model(
"coref_clusterer", "coref_clusterer",
forward=coref_forward, forward=coref_forward,
init=coref_init, init=coref_init,
@ -81,6 +81,7 @@ def coref_init(model: Model, X=None, Y=None):
def coref_forward(model: Model, X, is_train: bool): def coref_forward(model: Model, X, is_train: bool):
return model.layers[0](X, is_train) return model.layers[0](X, is_train)
def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool): def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
# The input here is List[Floats2d], one for each doc # The input here is List[Floats2d], one for each doc
# just use the first # just use the first

View File

@ -25,7 +25,7 @@ def build_span_predictor(
nI = None nI = None
with Model.define_operators({">>": chain, "&": tuplify}): with Model.define_operators({">>": chain, "&": tuplify}):
span_predictor = Model( span_predictor: Model[List[Floats2d], List[Floats2d]] = Model(
"span_predictor", "span_predictor",
forward=span_predictor_forward, forward=span_predictor_forward,
init=span_predictor_init, init=span_predictor_init,
@ -44,6 +44,7 @@ def build_span_predictor(
return model return model
def span_predictor_init(model: Model, X=None, Y=None): def span_predictor_init(model: Model, X=None, Y=None):
if model.layers: if model.layers:
return return
@ -72,9 +73,11 @@ def span_predictor_init(model: Model, X=None, Y=None):
# TODO maybe we need mixed precision and grad scaling? # TODO maybe we need mixed precision and grad scaling?
] ]
def span_predictor_forward(model: Model, X, is_train: bool): def span_predictor_forward(model: Model, X, is_train: bool):
return model.layers[0](X, is_train) return model.layers[0](X, is_train)
def convert_span_predictor_inputs( def convert_span_predictor_inputs(
model: Model, model: Model,
X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]],
@ -95,7 +98,9 @@ def convert_span_predictor_inputs(
else: else:
head_ids_tensor = xp2torch(head_ids[0], requires_grad=False) head_ids_tensor = xp2torch(head_ids[0], requires_grad=False)
argskwargs = ArgsKwargs(args=(sent_ids_tensor, word_features, head_ids_tensor), kwargs={}) argskwargs = ArgsKwargs(
args=(sent_ids_tensor, word_features, head_ids_tensor), kwargs={}
)
return argskwargs, backprop return argskwargs, backprop