diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index d59619498..96fad8019 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -56,6 +56,7 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bo def backprop(args: ArgsKwargs) -> List[Floats2d]: # convert to xp and wrap in list gradients = torch2xp(args.args[0]) + assert isinstance(gradients, Floats2d) return [gradients] return ArgsKwargs(args=(word_features,), kwargs={}), backprop diff --git a/spacy/ml/models/coref_util.py b/spacy/ml/models/coref_util.py index 8d0ff7bb0..dc9366a61 100644 --- a/spacy/ml/models/coref_util.py +++ b/spacy/ml/models/coref_util.py @@ -1,6 +1,6 @@ +from typing import List, Tuple, Set, Dict, cast from thinc.types import Ints2d from spacy.tokens import Doc -from typing import List, Tuple, Set # type alias to make writing this less tedious MentionClusters = List[List[Tuple[int, int]]] @@ -111,9 +111,9 @@ def select_non_crossing_spans( Nested spans are allowed. """ # ported from Model._extract_top_spans - selected = [] - start_to_max_end = {} - end_to_min_start = {} + selected: List[int] = [] + start_to_max_end: Dict[int, int] = {} + end_to_min_start: Dict[int, int] = {} for idx in idxs: if len(selected) >= limit or idx > len(starts): @@ -188,7 +188,7 @@ def create_gold_scores( """ # make a mapping of mentions to cluster id # id is not important but equality will be - ment2cid = {} + ment2cid: Dict[Tuple[int, int], int] = {} for cid, cluster in enumerate(clusters): for ment in cluster: ment2cid[ment] = cid @@ -196,7 +196,7 @@ def create_gold_scores( ll = len(ments) out = [] # The .tolist() call is necessary with cupy but not numpy - mentuples = [tuple(mm.tolist()) for mm in ments] + mentuples = [cast(Tuple[int, int], tuple(mm.tolist())) for mm in ments] for ii, ment in enumerate(mentuples): if ment not in ment2cid: # this is not in a cluster so it has no antecedent diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py index 76e790896..ebdb3b9d0 100644 --- a/spacy/pipeline/coref.py +++ b/spacy/pipeline/coref.py @@ -259,7 +259,7 @@ class CoreferenceResolver(TrainablePipe): span_idxs = create_head_span_idxs(ops, len(example.predicted)) gscores = create_gold_scores(span_idxs, clusters) # TODO fix type here. This is bools but asarray2f wants ints. - gscores = ops.asarray2f(gscores) + gscores = ops.asarray2f(gscores) # type: ignore # top_gscores = xp.take_along_axis(gscores, cidx, axis=1) top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1) # now add the placeholder