span accuracy score

This commit is contained in:
Kádár Ákos 2022-04-04 18:10:09 +02:00
parent a1d0219903
commit ef141ad399

View File

@ -457,6 +457,7 @@ class SpanPredictor(TrainablePipe):
for doc in docs:
# TODO check shape here
span_scores = self.model.predict([doc])
print(span_scores)
if span_scores.size:
# the information about clustering has to come from the input docs
# first let's convert the scores to a list of span idxs
@ -608,30 +609,35 @@ class SpanPredictor(TrainablePipe):
self.model.initialize(X=X, Y=Y)
def score(self, examples, **kwargs):
"""Score a batch of examples."""
# TODO This is basically the same as the main coref component - factor out?
"""
Evaluate on reconstructing the correct spans around
gold heads.
"""
scores = []
for metric in (b_cubed, muc, ceafe):
evaluator = Evaluator(metric)
for eg in examples:
starts = []
ends = []
pred_starts = []
pred_ends = []
ref = eg.reference
pred = eg.predicted
for key, gold_sg in ref.spans.items():
if key.startswith(self.input_prefix):
cluster_id = key.split('_')[-1]
# FIXME THIS DOESN'T WORK BECAUSE pred.spans are empty?
pred_sg = pred.spans[f"{self.output_prefix}_{cluster_id}"]
for gold_mention, pred_mention in zip(gold_sg, pred_sg):
starts.append(gold_mention.start)
ends.append(gold_mention.end)
pred_starts.append(pred_mention.start)
pred_ends.append(pred_mention.end)
for ex in examples:
# XXX this is the only different part
p_clusters = doc2clusters(ex.predicted, self.output_prefix)
g_clusters = doc2clusters(ex.reference, self.output_prefix)
cluster_info = get_cluster_info(p_clusters, g_clusters)
starts = self.model.ops.xp.asarray(starts)
ends = self.model.ops.xp.asarray(ends)
pred_starts = self.model.ops.xp.asarray(pred_starts)
pred_ends = self.model.ops.xp.asarray(pred_ends)
correct = ((starts == pred_starts) * (ends == pred_ends)).sum()
scores.append(correct)
evaluator.update(cluster_info)
score = {
"coref_span_f": evaluator.get_f1(),
"coref_span_p": evaluator.get_precision(),
"coref_span_r": evaluator.get_recall(),
}
scores.append(score)
out = {}
for field in ("f", "p", "r"):
fname = f"coref_span_{field}"
out[fname] = mean([ss[fname] for ss in scores])
out = {"span_accuracy": mean(scores)}
return out