diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index a69c6673e..8befeaef3 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -22,11 +22,11 @@ def build_wl_coref_model( # pairs to keep per mention after rough scoring antecedent_limit: int = 50, antecedent_batch_size: int = 512, - nI = None + nI=None, ) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]: with Model.define_operators({">>": chain}): - coref_clusterer = Model( + coref_clusterer: Model[List[Floats2d], Tuple[Floats2d, Ints2d]] = Model( "coref_clusterer", forward=coref_forward, init=coref_init, @@ -81,6 +81,7 @@ def coref_init(model: Model, X=None, Y=None): def coref_forward(model: Model, X, is_train: bool): return model.layers[0](X, is_train) + def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool): # The input here is List[Floats2d], one for each doc # just use the first diff --git a/spacy/ml/models/span_predictor.py b/spacy/ml/models/span_predictor.py index ca76e5a4a..8a93dcf8e 100644 --- a/spacy/ml/models/span_predictor.py +++ b/spacy/ml/models/span_predictor.py @@ -25,7 +25,7 @@ def build_span_predictor( nI = None with Model.define_operators({">>": chain, "&": tuplify}): - span_predictor = Model( + span_predictor: Model[List[Floats2d], List[Floats2d]] = Model( "span_predictor", forward=span_predictor_forward, init=span_predictor_init, @@ -44,6 +44,7 @@ def build_span_predictor( return model + def span_predictor_init(model: Model, X=None, Y=None): if model.layers: return @@ -72,9 +73,11 @@ def span_predictor_init(model: Model, X=None, Y=None): # TODO maybe we need mixed precision and grad scaling? ] + def span_predictor_forward(model: Model, X, is_train: bool): return model.layers[0](X, is_train) + def convert_span_predictor_inputs( model: Model, X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], @@ -95,7 +98,9 @@ def convert_span_predictor_inputs( else: head_ids_tensor = xp2torch(head_ids[0], requires_grad=False) - argskwargs = ArgsKwargs(args=(sent_ids_tensor, word_features, head_ids_tensor), kwargs={}) + argskwargs = ArgsKwargs( + args=(sent_ids_tensor, word_features, head_ids_tensor), kwargs={} + ) return argskwargs, backprop