mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 20:22:25 +03:00
fix types + black formatting
This commit is contained in:
parent
3fee6933d7
commit
cea40c9d7b
|
@ -56,6 +56,7 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bo
|
||||||
def backprop(args: ArgsKwargs) -> List[Floats2d]:
|
def backprop(args: ArgsKwargs) -> List[Floats2d]:
|
||||||
# convert to xp and wrap in list
|
# convert to xp and wrap in list
|
||||||
gradients = torch2xp(args.args[0])
|
gradients = torch2xp(args.args[0])
|
||||||
|
assert isinstance(gradients, Floats2d)
|
||||||
return [gradients]
|
return [gradients]
|
||||||
|
|
||||||
return ArgsKwargs(args=(word_features,), kwargs={}), backprop
|
return ArgsKwargs(args=(word_features,), kwargs={}), backprop
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
|
from typing import List, Tuple, Set, Dict, cast
|
||||||
from thinc.types import Ints2d
|
from thinc.types import Ints2d
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
from typing import List, Tuple, Set
|
|
||||||
|
|
||||||
# type alias to make writing this less tedious
|
# type alias to make writing this less tedious
|
||||||
MentionClusters = List[List[Tuple[int, int]]]
|
MentionClusters = List[List[Tuple[int, int]]]
|
||||||
|
@ -111,9 +111,9 @@ def select_non_crossing_spans(
|
||||||
Nested spans are allowed.
|
Nested spans are allowed.
|
||||||
"""
|
"""
|
||||||
# ported from Model._extract_top_spans
|
# ported from Model._extract_top_spans
|
||||||
selected = []
|
selected: List[int] = []
|
||||||
start_to_max_end = {}
|
start_to_max_end: Dict[int, int] = {}
|
||||||
end_to_min_start = {}
|
end_to_min_start: Dict[int, int] = {}
|
||||||
|
|
||||||
for idx in idxs:
|
for idx in idxs:
|
||||||
if len(selected) >= limit or idx > len(starts):
|
if len(selected) >= limit or idx > len(starts):
|
||||||
|
@ -188,7 +188,7 @@ def create_gold_scores(
|
||||||
"""
|
"""
|
||||||
# make a mapping of mentions to cluster id
|
# make a mapping of mentions to cluster id
|
||||||
# id is not important but equality will be
|
# id is not important but equality will be
|
||||||
ment2cid = {}
|
ment2cid: Dict[Tuple[int, int], int] = {}
|
||||||
for cid, cluster in enumerate(clusters):
|
for cid, cluster in enumerate(clusters):
|
||||||
for ment in cluster:
|
for ment in cluster:
|
||||||
ment2cid[ment] = cid
|
ment2cid[ment] = cid
|
||||||
|
@ -196,7 +196,7 @@ def create_gold_scores(
|
||||||
ll = len(ments)
|
ll = len(ments)
|
||||||
out = []
|
out = []
|
||||||
# The .tolist() call is necessary with cupy but not numpy
|
# The .tolist() call is necessary with cupy but not numpy
|
||||||
mentuples = [tuple(mm.tolist()) for mm in ments]
|
mentuples = [cast(Tuple[int, int], tuple(mm.tolist())) for mm in ments]
|
||||||
for ii, ment in enumerate(mentuples):
|
for ii, ment in enumerate(mentuples):
|
||||||
if ment not in ment2cid:
|
if ment not in ment2cid:
|
||||||
# this is not in a cluster so it has no antecedent
|
# this is not in a cluster so it has no antecedent
|
||||||
|
|
|
@ -259,7 +259,7 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
span_idxs = create_head_span_idxs(ops, len(example.predicted))
|
span_idxs = create_head_span_idxs(ops, len(example.predicted))
|
||||||
gscores = create_gold_scores(span_idxs, clusters)
|
gscores = create_gold_scores(span_idxs, clusters)
|
||||||
# TODO fix type here. This is bools but asarray2f wants ints.
|
# TODO fix type here. This is bools but asarray2f wants ints.
|
||||||
gscores = ops.asarray2f(gscores)
|
gscores = ops.asarray2f(gscores) # type: ignore
|
||||||
# top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
|
# top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
|
||||||
top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
|
top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
|
||||||
# now add the placeholder
|
# now add the placeholder
|
||||||
|
|
Loading…
Reference in New Issue
Block a user