diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py
index 660ef68c5..8b262ad3b 100644
--- a/spacy/ml/models/coref.py
+++ b/spacy/ml/models/coref.py
@@ -1,8 +1,8 @@
-from typing import List, Tuple
+from typing import List, Tuple, Callable, cast
 
 from thinc.api import Model, chain
 from thinc.api import PyTorchWrapper, ArgsKwargs
-from thinc.types import Floats2d
+from thinc.types import Floats2d, Ints2d
 from thinc.util import torch, xp2torch, torch2xp
 
 from ...tokens import Doc
@@ -23,9 +23,7 @@ def build_wl_coref_model(
     antecedent_limit: int = 50,
     antecedent_batch_size: int = 512,
     tok2vec_size: int = 768,  # tok2vec size
-):
-    # TODO add model return types
-    # dim = tok2vec.maybe_get_dim("n0")
+) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]:
 
     with Model.define_operators({">>": chain}):
         coref_clusterer = PyTorchWrapper(
@@ -45,27 +43,24 @@ def build_wl_coref_model(
     return coref_model
 
 
-def convert_coref_clusterer_inputs(
-    model: Model, X: List[Floats2d], is_train: bool
-):
+def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
     X = X[0]
 
     word_features = xp2torch(X, requires_grad=is_train)
 
-    # TODO fix or remove type annotations
-    def backprop(args: ArgsKwargs): #-> List[Floats2d]:
+    def backprop(args: ArgsKwargs) -> List[Floats2d]:
         # convert to xp and wrap in list
-        gradients = torch2xp(args.args[0])
+        gradients = cast(Floats2d, torch2xp(args.args[0]))
         return [gradients]
 
     return ArgsKwargs(args=(word_features,), kwargs={}), backprop
 
 
 def convert_coref_clusterer_outputs(
-    model: Model, inputs_outputs, is_train: bool
-):
+    model: Model, inputs_outputs, is_train: bool
+) -> Tuple[Tuple[Floats2d, Ints2d], Callable]:
     _, outputs = inputs_outputs
     scores, indices = outputs
 
@@ -76,8 +71,8 @@ def convert_coref_clusterer_outputs(
         kwargs={"grad_tensors": [dY_t]},
     )
 
-    scores_xp = torch2xp(scores)
-    indices_xp = torch2xp(indices)
+    scores_xp = cast(Floats2d, torch2xp(scores))
+    indices_xp = cast(Ints2d, torch2xp(indices))
     return (scores_xp, indices_xp), convert_for_torch_backward
 
 
@@ -115,9 +110,7 @@ class CorefClusterer(torch.nn.Module):
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
 
         pair_emb = dim * 3 + self.pw.shape
-        self.a_scorer = AnaphoricityScorer(
-            pair_emb, hidden_size, n_layers, dropout
-        )
+        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
         self.lstm = torch.nn.LSTM(
             input_size=dim,
             hidden_size=dim,
@@ -156,10 +149,10 @@ class CorefClusterer(torch.nn.Module):
         a_scores_lst: List[torch.Tensor] = []
 
         for i in range(0, len(words), batch_size):
-            pw_batch = pw[i:i + batch_size]
-            words_batch = words[i:i + batch_size]
-            top_indices_batch = top_indices[i:i + batch_size]
-            top_rough_scores_batch = top_rough_scores[i:i + batch_size]
+            pw_batch = pw[i : i + batch_size]
+            words_batch = words[i : i + batch_size]
+            top_indices_batch = top_indices[i : i + batch_size]
+            top_rough_scores_batch = top_rough_scores[i : i + batch_size]
 
             # a_scores_batch [batch_size, n_ants]
             a_scores_batch = self.a_scorer(
diff --git a/spacy/ml/models/span_predictor.py b/spacy/ml/models/span_predictor.py
index f11ecb5d5..1947b7833 100644
--- a/spacy/ml/models/span_predictor.py
+++ b/spacy/ml/models/span_predictor.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import List, Tuple, cast
 
 from thinc.api import Model, chain, tuplify
 from thinc.api import PyTorchWrapper, ArgsKwargs
@@ -42,15 +42,17 @@ def build_span_predictor(
 
 
 def convert_span_predictor_inputs(
-    model: Model, X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], is_train: bool
+    model: Model,
+    X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]],
+    is_train: bool,
 ):
     tok2vec, (sent_ids, head_ids) = X
     # Normally we should use the input is_train, but for these two it's not relevant
-    # TODO fix the type here, or remove it
-    def backprop(args: ArgsKwargs): #-> Tuple[List[Floats2d], None]:
-        gradients = torch2xp(args.args[1])
+
+    def backprop(args: ArgsKwargs) -> Tuple[List[Floats2d], None]:
+        gradients = cast(Floats2d, torch2xp(args.args[1]))
         # The sent_ids and head_ids are None because no gradients
-        return [[gradients], None]
+        return ([gradients], None)
 
     word_features = xp2torch(tok2vec[0], requires_grad=is_train)
     sent_ids_tensor = xp2torch(sent_ids[0], requires_grad=False)
@@ -207,9 +209,7 @@ class SpanPredictor(torch.nn.Module):
             dim=1,
         )
         lengths = same_sent.sum(dim=1)
-        padding_mask = torch.arange(
-            0, lengths.max().item(), device=device
-        ).unsqueeze(0)
+        padding_mask = torch.arange(0, lengths.max().item(), device=device).unsqueeze(0)
         # (n_heads x max_sent_len)
         padding_mask = padding_mask < lengths.unsqueeze(1)
         # (n_heads x max_sent_len x input_size * 2 + distance_emb_size)
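
Note for reviewers: Thinc annotates torch2xp as returning a generic array type, so the cast(Floats2d, ...) / cast(Ints2d, ...) calls added above only narrow the static type for mypy; they have no runtime effect. Below is a minimal sketch of the PyTorchWrapper conversion pattern these functions implement, reduced to one input array; the convert_inputs name and the Linear stand-in module are hypothetical, not part of this diff.

    from typing import cast

    import torch
    from thinc.api import ArgsKwargs, PyTorchWrapper
    from thinc.types import Floats2d
    from thinc.util import torch2xp, xp2torch


    def convert_inputs(model, X: Floats2d, is_train: bool):
        # Move the xp array onto torch, tracking gradients only while training.
        X_t = xp2torch(X, requires_grad=is_train)

        def backprop(args: ArgsKwargs) -> Floats2d:
            # torch2xp returns a generic array; cast() narrows it for mypy only.
            return cast(Floats2d, torch2xp(args.args[0]))

        return ArgsKwargs(args=(X_t,), kwargs={}), backprop


    # Hypothetical stand-in for CorefClusterer / SpanPredictor.
    wrapped = PyTorchWrapper(torch.nn.Linear(4, 4), convert_inputs=convert_inputs)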