mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	fix types + black formatting
This commit is contained in:
		
							parent
							
								
									3fee6933d7
								
							
						
					
					
						commit
						cea40c9d7b
					
				| 
						 | 
					@ -56,6 +56,7 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bo
 | 
				
			||||||
    def backprop(args: ArgsKwargs) -> List[Floats2d]:
 | 
					    def backprop(args: ArgsKwargs) -> List[Floats2d]:
 | 
				
			||||||
        # convert to xp and wrap in list
 | 
					        # convert to xp and wrap in list
 | 
				
			||||||
        gradients = torch2xp(args.args[0])
 | 
					        gradients = torch2xp(args.args[0])
 | 
				
			||||||
 | 
					        assert isinstance(gradients, Floats2d)
 | 
				
			||||||
        return [gradients]
 | 
					        return [gradients]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return ArgsKwargs(args=(word_features,), kwargs={}), backprop
 | 
					    return ArgsKwargs(args=(word_features,), kwargs={}), backprop
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,6 +1,6 @@
 | 
				
			||||||
 | 
					from typing import List, Tuple, Set, Dict, cast
 | 
				
			||||||
from thinc.types import Ints2d
 | 
					from thinc.types import Ints2d
 | 
				
			||||||
from spacy.tokens import Doc
 | 
					from spacy.tokens import Doc
 | 
				
			||||||
from typing import List, Tuple, Set
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
# type alias to make writing this less tedious
 | 
					# type alias to make writing this less tedious
 | 
				
			||||||
MentionClusters = List[List[Tuple[int, int]]]
 | 
					MentionClusters = List[List[Tuple[int, int]]]
 | 
				
			||||||
| 
						 | 
					@ -111,9 +111,9 @@ def select_non_crossing_spans(
 | 
				
			||||||
    Nested spans are allowed.
 | 
					    Nested spans are allowed.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    # ported from Model._extract_top_spans
 | 
					    # ported from Model._extract_top_spans
 | 
				
			||||||
    selected = []
 | 
					    selected: List[int] = []
 | 
				
			||||||
    start_to_max_end = {}
 | 
					    start_to_max_end: Dict[int, int] = {}
 | 
				
			||||||
    end_to_min_start = {}
 | 
					    end_to_min_start: Dict[int, int] = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for idx in idxs:
 | 
					    for idx in idxs:
 | 
				
			||||||
        if len(selected) >= limit or idx > len(starts):
 | 
					        if len(selected) >= limit or idx > len(starts):
 | 
				
			||||||
| 
						 | 
					@ -188,7 +188,7 @@ def create_gold_scores(
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    # make a mapping of mentions to cluster id
 | 
					    # make a mapping of mentions to cluster id
 | 
				
			||||||
    # id is not important but equality will be
 | 
					    # id is not important but equality will be
 | 
				
			||||||
    ment2cid = {}
 | 
					    ment2cid: Dict[Tuple[int, int], int] = {}
 | 
				
			||||||
    for cid, cluster in enumerate(clusters):
 | 
					    for cid, cluster in enumerate(clusters):
 | 
				
			||||||
        for ment in cluster:
 | 
					        for ment in cluster:
 | 
				
			||||||
            ment2cid[ment] = cid
 | 
					            ment2cid[ment] = cid
 | 
				
			||||||
| 
						 | 
					@ -196,7 +196,7 @@ def create_gold_scores(
 | 
				
			||||||
    ll = len(ments)
 | 
					    ll = len(ments)
 | 
				
			||||||
    out = []
 | 
					    out = []
 | 
				
			||||||
    # The .tolist() call is necessary with cupy but not numpy
 | 
					    # The .tolist() call is necessary with cupy but not numpy
 | 
				
			||||||
    mentuples = [tuple(mm.tolist()) for mm in ments]
 | 
					    mentuples = [cast(Tuple[int, int], tuple(mm.tolist())) for mm in ments]
 | 
				
			||||||
    for ii, ment in enumerate(mentuples):
 | 
					    for ii, ment in enumerate(mentuples):
 | 
				
			||||||
        if ment not in ment2cid:
 | 
					        if ment not in ment2cid:
 | 
				
			||||||
            # this is not in a cluster so it has no antecedent
 | 
					            # this is not in a cluster so it has no antecedent
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -259,7 +259,7 @@ class CoreferenceResolver(TrainablePipe):
 | 
				
			||||||
        span_idxs = create_head_span_idxs(ops, len(example.predicted))
 | 
					        span_idxs = create_head_span_idxs(ops, len(example.predicted))
 | 
				
			||||||
        gscores = create_gold_scores(span_idxs, clusters)
 | 
					        gscores = create_gold_scores(span_idxs, clusters)
 | 
				
			||||||
        # TODO fix type here. This is bools but asarray2f wants ints.
 | 
					        # TODO fix type here. This is bools but asarray2f wants ints.
 | 
				
			||||||
        gscores = ops.asarray2f(gscores)
 | 
					        gscores = ops.asarray2f(gscores)  # type: ignore
 | 
				
			||||||
        # top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
 | 
					        # top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
 | 
				
			||||||
        top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
 | 
					        top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
 | 
				
			||||||
        # now add the placeholder
 | 
					        # now add the placeholder
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user