Fix types

mypy now exits without errors, except for two apparently unrelated ones about setup.py.
Paul O'Leary McCann 2022-07-08 18:29:14 +09:00
parent ce49136458
commit 2eee0d248e
2 changed files with 23 additions and 30 deletions
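
The recurring fix in this commit is narrowing the output of thinc's torch2xp, which is typed as a generic array, to a concrete array type with typing.cast so mypy can check the surrounding code. A minimal standalone sketch of that pattern (illustrative only, not part of the committed files; the dummy numpy input is an assumption):

from typing import cast

import numpy
from thinc.types import Floats2d
from thinc.util import torch2xp, xp2torch

# Dummy 2d float features standing in for a doc's tok2vec output.
features = numpy.zeros((4, 8), dtype="f")
word_features = xp2torch(features, requires_grad=False)

# torch2xp returns a generic array type; cast() narrows it to Floats2d
# for mypy without adding any runtime check.
gradients = cast(Floats2d, torch2xp(word_features))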

File 1 of 2 (coref clusterer model):

@@ -1,8 +1,8 @@
-from typing import List, Tuple
+from typing import List, Tuple, Callable, cast
 from thinc.api import Model, chain
 from thinc.api import PyTorchWrapper, ArgsKwargs
-from thinc.types import Floats2d
+from thinc.types import Floats2d, Ints2d
 from thinc.util import torch, xp2torch, torch2xp
 from ...tokens import Doc

@@ -23,9 +23,7 @@ def build_wl_coref_model(
     antecedent_limit: int = 50,
     antecedent_batch_size: int = 512,
     tok2vec_size: int = 768,  # tok2vec size
-):
-    # TODO add model return types
-    # dim = tok2vec.maybe_get_dim("n0")
+) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]:
     with Model.define_operators({">>": chain}):
         coref_clusterer = PyTorchWrapper(

@@ -45,27 +43,24 @@ def build_wl_coref_model(
     return coref_model


-def convert_coref_clusterer_inputs(
-    model: Model, X: List[Floats2d], is_train: bool
-):
+def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
     X = X[0]
     word_features = xp2torch(X, requires_grad=is_train)

-    # TODO fix or remove type annotations
-    def backprop(args: ArgsKwargs): #-> List[Floats2d]:
+    def backprop(args: ArgsKwargs) -> List[Floats2d]:
         # convert to xp and wrap in list
-        gradients = torch2xp(args.args[0])
+        gradients = cast(Floats2d, torch2xp(args.args[0]))
         return [gradients]

     return ArgsKwargs(args=(word_features,), kwargs={}), backprop


 def convert_coref_clusterer_outputs(
     model: Model, inputs_outputs, is_train: bool
-):
+) -> Tuple[Tuple[Floats2d, Ints2d], Callable]:
     _, outputs = inputs_outputs
     scores, indices = outputs

@@ -76,8 +71,8 @@ def convert_coref_clusterer_outputs(
         kwargs={"grad_tensors": [dY_t]},
     )

-    scores_xp = torch2xp(scores)
-    indices_xp = torch2xp(indices)
+    scores_xp = cast(Floats2d, torch2xp(scores))
+    indices_xp = cast(Ints2d, torch2xp(indices))
     return (scores_xp, indices_xp), convert_for_torch_backward

@@ -115,9 +110,7 @@ class CorefClusterer(torch.nn.Module):
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)

         pair_emb = dim * 3 + self.pw.shape
-        self.a_scorer = AnaphoricityScorer(
-            pair_emb, hidden_size, n_layers, dropout
-        )
+        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
         self.lstm = torch.nn.LSTM(
             input_size=dim,
             hidden_size=dim,

@@ -156,10 +149,10 @@ class CorefClusterer(torch.nn.Module):
         a_scores_lst: List[torch.Tensor] = []

         for i in range(0, len(words), batch_size):
-            pw_batch = pw[i:i + batch_size]
-            words_batch = words[i:i + batch_size]
-            top_indices_batch = top_indices[i:i + batch_size]
-            top_rough_scores_batch = top_rough_scores[i:i + batch_size]
+            pw_batch = pw[i : i + batch_size]
+            words_batch = words[i : i + batch_size]
+            top_indices_batch = top_indices[i : i + batch_size]
+            top_rough_scores_batch = top_rough_scores[i : i + batch_size]
             # a_scores_batch  [batch_size, n_ants]
             a_scores_batch = self.a_scorer(

File 2 of 2 (span predictor model):

@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import List, Tuple, cast
 from thinc.api import Model, chain, tuplify
 from thinc.api import PyTorchWrapper, ArgsKwargs

@@ -42,15 +42,17 @@ def build_span_predictor(
 def convert_span_predictor_inputs(
-    model: Model, X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], is_train: bool
+    model: Model,
+    X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]],
+    is_train: bool,
 ):
     tok2vec, (sent_ids, head_ids) = X

     # Normally we should use the input is_train, but for these two it's not relevant
-    # TODO fix the type here, or remove it
-    def backprop(args: ArgsKwargs): #-> Tuple[List[Floats2d], None]:
-        gradients = torch2xp(args.args[1])
+    def backprop(args: ArgsKwargs) -> Tuple[List[Floats2d], None]:
+        gradients = cast(Floats2d, torch2xp(args.args[1]))
         # The sent_ids and head_ids are None because no gradients
-        return [[gradients], None]
+        return ([gradients], None)

     word_features = xp2torch(tok2vec[0], requires_grad=is_train)
     sent_ids_tensor = xp2torch(sent_ids[0], requires_grad=False)

@@ -207,9 +209,7 @@ class SpanPredictor(torch.nn.Module):
             dim=1,
         )
         lengths = same_sent.sum(dim=1)
-        padding_mask = torch.arange(
-            0, lengths.max().item(), device=device
-        ).unsqueeze(0)
+        padding_mask = torch.arange(0, lengths.max().item(), device=device).unsqueeze(0)
         # (n_heads x max_sent_len)
         padding_mask = padding_mask < lengths.unsqueeze(1)
         # (n_heads x max_sent_len x input_size * 2 + distance_emb_size)
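
A closing note on the span predictor change: once backprop is annotated as returning Tuple[List[Floats2d], None], mypy requires an actual tuple, which is why the return value changed from a list literal to a tuple. A small sketch of the shape the annotation now enforces (the helper name and dummy data are assumptions for illustration, not code from the commit):

from typing import List, Tuple, cast

import numpy
from thinc.types import Floats2d


def split_gradients(gradients: Floats2d) -> Tuple[List[Floats2d], None]:
    # The second element is None because sent_ids and head_ids get no gradients;
    # returning [[gradients], None] (a list) would not satisfy the Tuple annotation.
    return ([gradients], None)


grads = cast(Floats2d, numpy.zeros((2, 2), dtype="f"))
word_grads, id_grads = split_gradients(grads)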