report start- and end-accuracies separately

This commit is contained in:
Kádár Ákos 2022-04-08 14:57:19 +02:00
parent 2a1ad4c5d2
commit 7a239f2ec7

View File

@ -510,13 +510,11 @@ class SpanPredictor(TrainablePipe):
total_loss = 0 total_loss = 0
for eg in examples: for eg in examples:
# For update we use the gold coref_head_clusters
# in the reference.
span_scores, backprop = self.model.begin_update([eg.predicted]) span_scores, backprop = self.model.begin_update([eg.predicted])
loss, d_scores = self.get_loss([eg], span_scores) loss, d_scores = self.get_loss([eg], span_scores)
total_loss += loss total_loss += loss
# TODO check shape here # TODO check shape here
backprop(d_scores) backprop((d_scores))
if sgd is not None: if sgd is not None:
self.finish_update(sgd) self.finish_update(sgd)
@ -612,7 +610,8 @@ class SpanPredictor(TrainablePipe):
Evaluate on reconstructing the correct spans around Evaluate on reconstructing the correct spans around
gold heads. gold heads.
""" """
scores = [] start_scores = []
end_scores = []
for eg in examples: for eg in examples:
starts = [] starts = []
ends = [] ends = []
@ -622,7 +621,6 @@ class SpanPredictor(TrainablePipe):
pred = eg.predicted pred = eg.predicted
for key, gold_sg in ref.spans.items(): for key, gold_sg in ref.spans.items():
if key.startswith(self.output_prefix): if key.startswith(self.output_prefix):
cluster_id = key.split('_')[-1]
pred_sg = pred.spans[key] pred_sg = pred.spans[key]
for gold_mention, pred_mention in zip(gold_sg, pred_sg): for gold_mention, pred_mention in zip(gold_sg, pred_sg):
starts.append(gold_mention.start) starts.append(gold_mention.start)
@ -634,8 +632,12 @@ class SpanPredictor(TrainablePipe):
ends = self.model.ops.xp.asarray(ends) ends = self.model.ops.xp.asarray(ends)
pred_starts = self.model.ops.xp.asarray(pred_starts) pred_starts = self.model.ops.xp.asarray(pred_starts)
pred_ends = self.model.ops.xp.asarray(pred_ends) pred_ends = self.model.ops.xp.asarray(pred_ends)
correct = ((starts == pred_starts) * (ends == pred_ends)).sum() start_accuracy = (starts == pred_starts).mean()
scores.append(correct) end_accuracy = (ends == pred_ends).mean()
start_scores.append(float(start_accuracy))
out = {"span_accuracy": mean(scores)} end_scores.append(float(end_accuracy))
out = {
"span_start_accuracy": mean(start_scores),
"span_end_accuracy": mean(end_scores)
}
return out return out