Formatting

This commit is contained in:
Paul O'Leary McCann 2022-03-16 20:10:47 +09:00
parent 6855df0e66
commit 1a79d18796

View File

@ -60,6 +60,7 @@ DEFAULT_MODEL = Config().from_str(default_config)["model"]
DEFAULT_CLUSTERS_PREFIX = "coref_clusters" DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
@Language.factory( @Language.factory(
"coref", "coref",
assigns=["doc.spans"], assigns=["doc.spans"],
@ -133,7 +134,7 @@ class CoreferenceResolver(TrainablePipe):
# each item in scores includes scores and a mapping from scores to mentions # each item in scores includes scores and a mapping from scores to mentions
ant_idxs = idxs ant_idxs = idxs
#TODO batching # TODO batching
xp = self.model.ops.xp xp = self.model.ops.xp
starts = xp.arange(0, len(docs[0])) starts = xp.arange(0, len(docs[0]))
@ -242,7 +243,7 @@ class CoreferenceResolver(TrainablePipe):
if self._rehearsal_model is None: if self._rehearsal_model is None:
return losses return losses
validate_examples(examples, "CoreferenceResolver.rehearse") validate_examples(examples, "CoreferenceResolver.rehearse")
#TODO test this whole function # TODO test this whole function
docs = [eg.predicted for eg in examples] docs = [eg.predicted for eg in examples]
if not any(len(doc) for doc in docs): if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
@ -256,7 +257,7 @@ class CoreferenceResolver(TrainablePipe):
if sgd is not None: if sgd is not None:
self.finish_update(sgd) self.finish_update(sgd)
if losses is not None: if losses is not None:
losses[self.name] += (gradient ** 2).sum() losses[self.name] += (gradient**2).sum()
return losses return losses
def add_label(self, label: str) -> int: def add_label(self, label: str) -> int:
@ -290,13 +291,13 @@ class CoreferenceResolver(TrainablePipe):
offset = 0 offset = 0
gradients = [] gradients = []
total_loss = 0 total_loss = 0
#TODO change this # TODO change this
# 1. do not handle batching (add it back later) # 1. do not handle batching (add it back later)
# 2. don't do index conversion (no mentions, just word indices) # 2. don't do index conversion (no mentions, just word indices)
# 3. convert words to spans (if necessary) in gold and predictions # 3. convert words to spans (if necessary) in gold and predictions
# massage score matrix to be shaped correctly # massage score matrix to be shaped correctly
score_matrix = [ (score_matrix, None) ] score_matrix = [(score_matrix, None)]
for example, (cscores, cidx) in zip(examples, score_matrix): for example, (cscores, cidx) in zip(examples, score_matrix):
ll = cscores.shape[0] ll = cscores.shape[0]
@ -306,7 +307,7 @@ class CoreferenceResolver(TrainablePipe):
span_idxs = create_head_span_idxs(ops, len(example.predicted)) span_idxs = create_head_span_idxs(ops, len(example.predicted))
gscores = create_gold_scores(span_idxs, clusters) gscores = create_gold_scores(span_idxs, clusters)
gscores = ops.asarray2f(gscores) gscores = ops.asarray2f(gscores)
#top_gscores = xp.take_along_axis(gscores, cidx, axis=1) # top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1) top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
# now add the placeholder # now add the placeholder
gold_placeholder = ~top_gscores.any(axis=1).T gold_placeholder = ~top_gscores.any(axis=1).T
@ -322,7 +323,7 @@ class CoreferenceResolver(TrainablePipe):
log_norm = ops.softmax(cscores, axis=1) log_norm = ops.softmax(cscores, axis=1)
grad = log_norm - log_marg grad = log_norm - log_marg
gradients.append((grad, cidx)) gradients.append((grad, cidx))
total_loss += float((grad ** 2).sum()) total_loss += float((grad**2).sum())
offset = hi offset = hi
@ -389,6 +390,7 @@ class CoreferenceResolver(TrainablePipe):
out[fname] = mean([ss[fname] for ss in scores]) out[fname] = mean([ss[fname] for ss in scores])
return out return out
class SpanPredictor(TrainablePipe): class SpanPredictor(TrainablePipe):
"""Pipeline component to resolve one-token spans to full spans. """Pipeline component to resolve one-token spans to full spans.
@ -451,7 +453,7 @@ class SpanPredictor(TrainablePipe):
def get_loss( def get_loss(
self, self,
examples: Iterable[Example], examples: Iterable[Example],
#TODO add necessary args # TODO add necessary args
): ):
... ...
@ -475,4 +477,3 @@ class SpanPredictor(TrainablePipe):
def score(self, examples, **kwargs): def score(self, examples, **kwargs):
# TODO this will overlap significantly with coref, maybe factor into function # TODO this will overlap significantly with coref, maybe factor into function
... ...