Run black

This commit is contained in:
Paul O'Leary McCann 2021-07-18 20:20:22 +09:00
parent bc081c24fa
commit 8bd0474730
3 changed files with 33 additions and 29 deletions

View File

@ -20,11 +20,11 @@ def build_coref(
hidden: int = 1000, hidden: int = 1000,
dropout: float = 0.3, dropout: float = 0.3,
mention_limit: int = 3900, mention_limit: int = 3900,
#TODO this needs a better name. It limits the max mentions as a ratio of # TODO this needs a better name. It limits the max mentions as a ratio of
# the token count. # the token count.
mention_limit_ratio: float = 0.4, mention_limit_ratio: float = 0.4,
max_span_width: int = 20, max_span_width: int = 20,
antecedent_limit: int = 50 antecedent_limit: int = 50,
): ):
dim = tok2vec.get_dim("nO") * 3 dim = tok2vec.get_dim("nO") * 3
@ -40,7 +40,7 @@ def build_coref(
) )
mention_scorer.initialize() mention_scorer.initialize()
#TODO make feature_embed_size a param # TODO make feature_embed_size a param
feature_embed_size = 20 feature_embed_size = 20
width_scorer = build_width_scorer(max_span_width, hidden, feature_embed_size) width_scorer = build_width_scorer(max_span_width, hidden, feature_embed_size)
@ -90,19 +90,18 @@ def build_width_scorer(max_span_width, hidden_size, feature_embed_size=20):
>> Linear(nI=hidden_size, nO=1) >> Linear(nI=hidden_size, nO=1)
) )
span_width_prior.initialize() span_width_prior.initialize()
model = Model( model = Model("WidthScorer", forward=width_score_forward, layers=[span_width_prior])
"WidthScorer",
forward=width_score_forward,
layers=[span_width_prior])
model.set_ref("width_prior", span_width_prior) model.set_ref("width_prior", span_width_prior)
return model return model
def width_score_forward(model, embeds: SpanEmbeddings, is_train) -> Tuple[Floats1d, Callable]: def width_score_forward(
model, embeds: SpanEmbeddings, is_train
) -> Tuple[Floats1d, Callable]:
# calculate widths, subtracting 1 so it's 0-index # calculate widths, subtracting 1 so it's 0-index
w_ffnn = model.get_ref("width_prior") w_ffnn = model.get_ref("width_prior")
idxs = embeds.indices idxs = embeds.indices
widths = idxs[:,1] - idxs[:,0] - 1 widths = idxs[:, 1] - idxs[:, 0] - 1
wscores, width_b = w_ffnn(widths, is_train) wscores, width_b = w_ffnn(widths, is_train)
lens = embeds.vectors.lengths lens = embeds.vectors.lengths
@ -115,6 +114,7 @@ def width_score_forward(model, embeds: SpanEmbeddings, is_train) -> Tuple[Floats
return wscores, width_score_backward return wscores, width_score_backward
# model converting a Doc/Mention to span embeddings # model converting a Doc/Mention to span embeddings
# get_mentions: Callable[Doc, Pairs[int]] # get_mentions: Callable[Doc, Pairs[int]]
def build_span_embedder( def build_span_embedder(
@ -123,8 +123,9 @@ def build_span_embedder(
) -> Model[Tuple[List[Floats2d], List[Doc]], SpanEmbeddings]: ) -> Model[Tuple[List[Floats2d], List[Doc]], SpanEmbeddings]:
with Model.define_operators({">>": chain, "|": concatenate}): with Model.define_operators({">>": chain, "|": concatenate}):
span_reduce = (extract_spans() >> span_reduce = extract_spans() >> (
(reduce_first() | reduce_last() | reduce_mean())) reduce_first() | reduce_last() | reduce_mean()
)
model = Model( model = Model(
"SpanEmbedding", "SpanEmbedding",
forward=span_embeddings_forward, forward=span_embeddings_forward,
@ -161,7 +162,7 @@ def span_embeddings_forward(
docmenlens.append(len(starts)) docmenlens.append(len(starts))
cments = ops.asarray2i([starts, ends]).transpose() cments = ops.asarray2i([starts, ends]).transpose()
mentions = xp.concatenate( (mentions, cments) ) mentions = xp.concatenate((mentions, cments))
# TODO support attention here # TODO support attention here
tokvecs = xp.concatenate(tokvecs) tokvecs = xp.concatenate(tokvecs)
@ -170,7 +171,7 @@ def span_embeddings_forward(
mentions_r = Ragged(mentions, docmenlens) mentions_r = Ragged(mentions, docmenlens)
span_reduce = model.get_ref("span_reducer") span_reduce = model.get_ref("span_reducer")
spanvecs, span_reduce_back = span_reduce( (tokvecs_r, mentions_r), is_train) spanvecs, span_reduce_back = span_reduce((tokvecs_r, mentions_r), is_train)
embeds = Ragged(spanvecs, docmenlens) embeds = Ragged(spanvecs, docmenlens)
@ -236,7 +237,7 @@ def coarse_prune(
# calculate the doc length # calculate the doc length
doclen = ends[-1] - starts[0] doclen = ends[-1] - starts[0]
# XXX seems to make more sense to use menlen than doclen here? # XXX seems to make more sense to use menlen than doclen here?
#mlimit = min(mention_limit, int(mention_limit_ratio * doclen)) # mlimit = min(mention_limit, int(mention_limit_ratio * doclen))
mlimit = min(mention_limit, int(mention_limit_ratio * menlen)) mlimit = min(mention_limit, int(mention_limit_ratio * menlen))
# csel is a 1d integer list # csel is a 1d integer list
csel = select_non_crossing_spans(tops, starts, ends, mlimit) csel = select_non_crossing_spans(tops, starts, ends, mlimit)
@ -290,6 +291,7 @@ def build_take_vecs() -> Model[SpanEmbeddings, Floats2d]:
def take_vecs_forward(model, inputs: SpanEmbeddings, is_train) -> Floats2d: def take_vecs_forward(model, inputs: SpanEmbeddings, is_train) -> Floats2d:
idxs = inputs.indices idxs = inputs.indices
lens = inputs.vectors.lengths lens = inputs.vectors.lengths
def backprop(dY: Floats2d) -> SpanEmbeddings: def backprop(dY: Floats2d) -> SpanEmbeddings:
vecs = Ragged(dY, lens) vecs = Ragged(dY, lens)
return SpanEmbeddings(idxs, vecs) return SpanEmbeddings(idxs, vecs)
@ -350,10 +352,10 @@ def ant_scorer_forward(
# This will take the log of 0, which causes a warning, but we're doing # This will take the log of 0, which causes a warning, but we're doing
# it on purpose so we can just ignore the warning. # it on purpose so we can just ignore the warning.
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=RuntimeWarning) warnings.filterwarnings("ignore", category=RuntimeWarning)
mask = xp.log( mask = xp.log(
(xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1 (xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1
).astype('f') ).astype("f")
scores = pw_prod + pw_sum + mask scores = pw_prod + pw_sum + mask
@ -361,7 +363,7 @@ def ant_scorer_forward(
top_scores, top_scores_idx = topk(xp, scores, top_limit) top_scores, top_scores_idx = topk(xp, scores, top_limit)
# now add the placeholder # now add the placeholder
placeholder = ops.alloc2f(scores.shape[0], 1) placeholder = ops.alloc2f(scores.shape[0], 1)
top_scores = xp.concatenate( (placeholder, top_scores), 1) top_scores = xp.concatenate((placeholder, top_scores), 1)
out.append((top_scores, top_scores_idx)) out.append((top_scores, top_scores_idx))
@ -398,8 +400,8 @@ def ant_scorer_forward(
for ii, (ridx, rscores) in enumerate(zip(dyidx, dyscore)): for ii, (ridx, rscores) in enumerate(zip(dyidx, dyscore)):
fullscore[ii][ridx] = rscores fullscore[ii][ridx] = rscores
dXembeds.data[offset : hi] = prod_back(fullscore) dXembeds.data[offset:hi] = prod_back(fullscore)
dXscores[offset : hi] = pw_sum_back(fullscore) dXscores[offset:hi] = pw_sum_back(fullscore)
offset = hi offset = hi
# make it fit back into the linear # make it fit back into the linear
@ -421,7 +423,7 @@ def pairwise_sum(ops, mention_scores: Floats1d) -> Tuple[Floats2d, Callable]:
def backward(d_pwsum: Floats2d) -> Floats1d: def backward(d_pwsum: Floats2d) -> Floats1d:
# For the backward pass, the gradient is distributed over the whole row and # For the backward pass, the gradient is distributed over the whole row and
# column, so pull it all in. # column, so pull it all in.
out = d_pwsum.sum(axis=0) + d_pwsum.sum(axis=1) out = d_pwsum.sum(axis=0) + d_pwsum.sum(axis=1)
return out return out

View File

@ -129,7 +129,7 @@ def get_candidate_mentions(
for ii in range(1, max_span_width): for ii in range(1, max_span_width):
ei = tok.i + ii # end index ei = tok.i + ii # end index
# Note: this matches slice syntax, so the token index is one less # Note: this matches slice syntax, so the token index is one less
if ei > len(doc) or sentence_map[ei-1] != si: if ei > len(doc) or sentence_map[ei - 1] != si:
continue continue
begins.append(tok.i) begins.append(tok.i)

View File

@ -79,7 +79,9 @@ def make_coref(
) -> "CoreferenceResolver": ) -> "CoreferenceResolver":
"""Create a CoreferenceResolver component.""" """Create a CoreferenceResolver component."""
return CoreferenceResolver(nlp.vocab, model, name, span_cluster_prefix=span_cluster_prefix) return CoreferenceResolver(
nlp.vocab, model, name, span_cluster_prefix=span_cluster_prefix
)
class CoreferenceResolver(TrainablePipe): class CoreferenceResolver(TrainablePipe):
@ -308,7 +310,7 @@ class CoreferenceResolver(TrainablePipe):
top_gscores = ops.asarray2f(top_gscores) top_gscores = ops.asarray2f(top_gscores)
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=RuntimeWarning) warnings.filterwarnings("ignore", category=RuntimeWarning)
log_marg = ops.softmax(cscores + ops.xp.log(top_gscores), axis=1) log_marg = ops.softmax(cscores + ops.xp.log(top_gscores), axis=1)
log_norm = ops.softmax(cscores, axis=1) log_norm = ops.softmax(cscores, axis=1)
grad = log_norm - log_marg grad = log_norm - log_marg
@ -351,7 +353,7 @@ class CoreferenceResolver(TrainablePipe):
def score(self, examples, **kwargs): def score(self, examples, **kwargs):
"""Score a batch of examples.""" """Score a batch of examples."""
#NOTE traditionally coref uses the average of b_cubed, muc, and ceaf. # NOTE traditionally coref uses the average of b_cubed, muc, and ceaf.
# we need to handle the average ourselves. # we need to handle the average ourselves.
scores = [] scores = []
for metric in (b_cubed, muc, ceafe): for metric in (b_cubed, muc, ceafe):
@ -365,11 +367,11 @@ class CoreferenceResolver(TrainablePipe):
evaluator.update(cluster_info) evaluator.update(cluster_info)
score ={ score = {
"coref_f": evaluator.get_f1(), "coref_f": evaluator.get_f1(),
"coref_p": evaluator.get_precision(), "coref_p": evaluator.get_precision(),
"coref_r": evaluator.get_recall(), "coref_r": evaluator.get_recall(),
} }
scores.append(score) scores.append(score)
out = {} out = {}