mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 04:32:32 +03:00
Run black
This commit is contained in:
parent
bc081c24fa
commit
8bd0474730
|
@ -24,7 +24,7 @@ def build_coref(
|
|||
# the token count.
|
||||
mention_limit_ratio: float = 0.4,
|
||||
max_span_width: int = 20,
|
||||
antecedent_limit: int = 50
|
||||
antecedent_limit: int = 50,
|
||||
):
|
||||
dim = tok2vec.get_dim("nO") * 3
|
||||
|
||||
|
@ -90,15 +90,14 @@ def build_width_scorer(max_span_width, hidden_size, feature_embed_size=20):
|
|||
>> Linear(nI=hidden_size, nO=1)
|
||||
)
|
||||
span_width_prior.initialize()
|
||||
model = Model(
|
||||
"WidthScorer",
|
||||
forward=width_score_forward,
|
||||
layers=[span_width_prior])
|
||||
model = Model("WidthScorer", forward=width_score_forward, layers=[span_width_prior])
|
||||
model.set_ref("width_prior", span_width_prior)
|
||||
return model
|
||||
|
||||
|
||||
def width_score_forward(model, embeds: SpanEmbeddings, is_train) -> Tuple[Floats1d, Callable]:
|
||||
def width_score_forward(
|
||||
model, embeds: SpanEmbeddings, is_train
|
||||
) -> Tuple[Floats1d, Callable]:
|
||||
# calculate widths, subtracting 1 so it's 0-index
|
||||
w_ffnn = model.get_ref("width_prior")
|
||||
idxs = embeds.indices
|
||||
|
@ -115,6 +114,7 @@ def width_score_forward(model, embeds: SpanEmbeddings, is_train) -> Tuple[Floats
|
|||
|
||||
return wscores, width_score_backward
|
||||
|
||||
|
||||
# model converting a Doc/Mention to span embeddings
|
||||
# get_mentions: Callable[Doc, Pairs[int]]
|
||||
def build_span_embedder(
|
||||
|
@ -123,8 +123,9 @@ def build_span_embedder(
|
|||
) -> Model[Tuple[List[Floats2d], List[Doc]], SpanEmbeddings]:
|
||||
|
||||
with Model.define_operators({">>": chain, "|": concatenate}):
|
||||
span_reduce = (extract_spans() >>
|
||||
(reduce_first() | reduce_last() | reduce_mean()))
|
||||
span_reduce = extract_spans() >> (
|
||||
reduce_first() | reduce_last() | reduce_mean()
|
||||
)
|
||||
model = Model(
|
||||
"SpanEmbedding",
|
||||
forward=span_embeddings_forward,
|
||||
|
@ -290,6 +291,7 @@ def build_take_vecs() -> Model[SpanEmbeddings, Floats2d]:
|
|||
def take_vecs_forward(model, inputs: SpanEmbeddings, is_train) -> Floats2d:
|
||||
idxs = inputs.indices
|
||||
lens = inputs.vectors.lengths
|
||||
|
||||
def backprop(dY: Floats2d) -> SpanEmbeddings:
|
||||
vecs = Ragged(dY, lens)
|
||||
return SpanEmbeddings(idxs, vecs)
|
||||
|
@ -350,10 +352,10 @@ def ant_scorer_forward(
|
|||
# This will take the log of 0, which causes a warning, but we're doing
|
||||
# it on purpose so we can just ignore the warning.
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=RuntimeWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
mask = xp.log(
|
||||
(xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1
|
||||
).astype('f')
|
||||
).astype("f")
|
||||
|
||||
scores = pw_prod + pw_sum + mask
|
||||
|
||||
|
|
|
@ -79,7 +79,9 @@ def make_coref(
|
|||
) -> "CoreferenceResolver":
|
||||
"""Create a CoreferenceResolver component."""
|
||||
|
||||
return CoreferenceResolver(nlp.vocab, model, name, span_cluster_prefix=span_cluster_prefix)
|
||||
return CoreferenceResolver(
|
||||
nlp.vocab, model, name, span_cluster_prefix=span_cluster_prefix
|
||||
)
|
||||
|
||||
|
||||
class CoreferenceResolver(TrainablePipe):
|
||||
|
@ -308,7 +310,7 @@ class CoreferenceResolver(TrainablePipe):
|
|||
top_gscores = ops.asarray2f(top_gscores)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=RuntimeWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
log_marg = ops.softmax(cscores + ops.xp.log(top_gscores), axis=1)
|
||||
log_norm = ops.softmax(cscores, axis=1)
|
||||
grad = log_norm - log_marg
|
||||
|
|
Loading…
Reference in New Issue
Block a user