Fix types

mypy now exits without errors, except for two apparently unrelated
ones about setup.py.
Paul O'Leary McCann 2022-07-08 18:29:14 +09:00
parent ce49136458
commit 2eee0d248e
2 changed files with 23 additions and 30 deletions
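The recurring fix below wraps thinc's torch2xp in typing.cast, since mypy sees torch2xp as returning a general array type and cannot narrow it to Floats2d or Ints2d on its own. A minimal standalone sketch of that pattern, assuming thinc and PyTorch are installed (it is not part of the commit):

# Sketch only: the cast() pattern this commit applies to torch2xp results.
from typing import cast

import torch
from thinc.types import Floats2d
from thinc.util import torch2xp

t = torch.zeros((2, 3))
# torch2xp returns a loosely typed array; cast() narrows the type for mypy
# at zero runtime cost (it does not convert or copy anything).
arr = cast(Floats2d, torch2xp(t))
print(arr.shape)  # (2, 3)

The first changed file covers the coref clusterer model: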


@@ -1,8 +1,8 @@
-from typing import List, Tuple
+from typing import List, Tuple, Callable, cast
 from thinc.api import Model, chain
 from thinc.api import PyTorchWrapper, ArgsKwargs
-from thinc.types import Floats2d
+from thinc.types import Floats2d, Ints2d
 from thinc.util import torch, xp2torch, torch2xp
 from ...tokens import Doc
@@ -23,9 +23,7 @@ def build_wl_coref_model(
     antecedent_limit: int = 50,
     antecedent_batch_size: int = 512,
     tok2vec_size: int = 768, # tok2vec size
-):
-    # TODO add model return types
-    # dim = tok2vec.maybe_get_dim("n0")
+) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]:
     with Model.define_operators({">>": chain}):
         coref_clusterer = PyTorchWrapper(
@@ -45,27 +43,24 @@ def build_wl_coref_model(
     return coref_model


-def convert_coref_clusterer_inputs(
-    model: Model, X: List[Floats2d], is_train: bool
-):
+def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
     X = X[0]
     word_features = xp2torch(X, requires_grad=is_train)

-    # TODO fix or remove type annotations
-    def backprop(args: ArgsKwargs): #-> List[Floats2d]:
+    def backprop(args: ArgsKwargs) -> List[Floats2d]:
         # convert to xp and wrap in list
-        gradients = torch2xp(args.args[0])
+        gradients = cast(Floats2d, torch2xp(args.args[0]))
         return [gradients]

     return ArgsKwargs(args=(word_features,), kwargs={}), backprop


 def convert_coref_clusterer_outputs(
-    model: Model, inputs_outputs, is_train: bool
-):
+    model: Model, inputs_outputs, is_train: bool
+) -> Tuple[Tuple[Floats2d, Ints2d], Callable]:
     _, outputs = inputs_outputs
     scores, indices = outputs
@@ -76,8 +71,8 @@ def convert_coref_clusterer_outputs(
             kwargs={"grad_tensors": [dY_t]},
         )
-    scores_xp = torch2xp(scores)
-    indices_xp = torch2xp(indices)
+    scores_xp = cast(Floats2d, torch2xp(scores))
+    indices_xp = cast(Ints2d, torch2xp(indices))
     return (scores_xp, indices_xp), convert_for_torch_backward
@@ -115,9 +110,7 @@ class CorefClusterer(torch.nn.Module):
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
         pair_emb = dim * 3 + self.pw.shape
-        self.a_scorer = AnaphoricityScorer(
-            pair_emb, hidden_size, n_layers, dropout
-        )
+        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
         self.lstm = torch.nn.LSTM(
             input_size=dim,
             hidden_size=dim,
@@ -156,10 +149,10 @@ class CorefClusterer(torch.nn.Module):
         a_scores_lst: List[torch.Tensor] = []

         for i in range(0, len(words), batch_size):
-            pw_batch = pw[i:i + batch_size]
-            words_batch = words[i:i + batch_size]
-            top_indices_batch = top_indices[i:i + batch_size]
-            top_rough_scores_batch = top_rough_scores[i:i + batch_size]
+            pw_batch = pw[i : i + batch_size]
+            words_batch = words[i : i + batch_size]
+            top_indices_batch = top_indices[i : i + batch_size]
+            top_rough_scores_batch = top_rough_scores[i : i + batch_size]
             # a_scores_batch [batch_size, n_ants]
             a_scores_batch = self.a_scorer(
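With build_wl_coref_model now annotated as Model[List[Doc], Tuple[Floats2d, Ints2d]], mypy can also check call sites. An illustrative consumer, where run_clusterer is a hypothetical helper rather than anything in the commit:

# Hypothetical consumer: the factory's new return annotation lets mypy
# verify that predict() yields a (scores, indices) pair.
from typing import List, Tuple

from spacy.tokens import Doc
from thinc.api import Model
from thinc.types import Floats2d, Ints2d


def run_clusterer(
    model: Model[List[Doc], Tuple[Floats2d, Ints2d]], docs: List[Doc]
) -> Ints2d:
    scores, indices = model.predict(docs)
    # mypy infers scores: Floats2d and indices: Ints2d here.
    return indices

The second changed file gives the span predictor the same treatment: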


@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import List, Tuple, cast
 from thinc.api import Model, chain, tuplify
 from thinc.api import PyTorchWrapper, ArgsKwargs
@@ -42,15 +42,17 @@ def build_span_predictor(
 def convert_span_predictor_inputs(
-    model: Model, X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], is_train: bool
+    model: Model,
+    X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]],
+    is_train: bool,
 ):
     tok2vec, (sent_ids, head_ids) = X
     # Normally we should use the input is_train, but for these two it's not relevant
-    # TODO fix the type here, or remove it
-    def backprop(args: ArgsKwargs): #-> Tuple[List[Floats2d], None]:
-        gradients = torch2xp(args.args[1])
+    def backprop(args: ArgsKwargs) -> Tuple[List[Floats2d], None]:
+        gradients = cast(Floats2d, torch2xp(args.args[1]))
         # The sent_ids and head_ids are None because no gradients
-        return [[gradients], None]
+        return ([gradients], None)

     word_features = xp2torch(tok2vec[0], requires_grad=is_train)
     sent_ids_tensor = xp2torch(sent_ids[0], requires_grad=False)
@@ -207,9 +209,7 @@ class SpanPredictor(torch.nn.Module):
             dim=1,
         )
         lengths = same_sent.sum(dim=1)
-        padding_mask = torch.arange(
-            0, lengths.max().item(), device=device
-        ).unsqueeze(0)
+        padding_mask = torch.arange(0, lengths.max().item(), device=device).unsqueeze(0)
         # (n_heads x max_sent_len)
         padding_mask = padding_mask < lengths.unsqueeze(1)
         # (n_heads x max_sent_len x input_size * 2 + distance_emb_size)
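One detail worth noting in the convert_span_predictor_inputs hunk above: the old backprop returned [[gradients], None], a list, while the new annotation Tuple[List[Floats2d], None] calls for a tuple, which the rewritten return now provides. A standalone sketch of that shape with illustrative values (not spaCy code):

# Sketch of the corrected backprop return shape: the annotation requires a
# tuple, so the old list-valued [[gradients], None] was a container mismatch.
from typing import List, Tuple, cast

import numpy
from thinc.types import Floats2d


def backprop_sketch(gradients: Floats2d) -> Tuple[List[Floats2d], None]:
    # sent_ids and head_ids receive no gradient, hence the trailing None
    return ([gradients], None)


grads = cast(Floats2d, numpy.zeros((4, 2), dtype="float32"))
print(backprop_sketch(grads))  # ([array(...)], None)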