fixing scorer

This commit is contained in:
Kádár Ákos 2022-04-11 16:10:14 +02:00
parent 7a239f2ec7
commit 6aedd98d02

View File

@@ -511,6 +511,9 @@ class SpanPredictor(TrainablePipe):
         total_loss = 0
         for eg in examples:
             span_scores, backprop = self.model.begin_update([eg.predicted])
+            # FIXME, this only happens once in the first 1000 docs of OntoNotes
+            # and I'm not sure yet why.
+            if span_scores.size:
                 loss, d_scores = self.get_loss([eg], span_scores)
                 total_loss += loss
                 # TODO check shape here
@@ -557,7 +560,6 @@ class SpanPredictor(TrainablePipe):
         assert len(examples) == 1, "Only fake batching is supported."
         # starts and ends are gold starts and ends (Ints1d)
         # span_scores is a Floats3d. What are the axes? mention x token x start/end
-
         for eg in examples:
             starts = []
             ends = []
@@ -610,8 +612,8 @@ class SpanPredictor(TrainablePipe):
         Evaluate on reconstructing the correct spans around
         gold heads.
         """
-        start_scores = []
-        end_scores = []
+        scores = []
+        xp = self.model.ops.xp
         for eg in examples:
             starts = []
             ends = []
@@ -628,16 +630,11 @@ class SpanPredictor(TrainablePipe):
                 pred_starts.append(pred_mention.start)
                 pred_ends.append(pred_mention.end)
-            starts = self.model.ops.xp.asarray(starts)
-            ends = self.model.ops.xp.asarray(ends)
-            pred_starts = self.model.ops.xp.asarray(pred_starts)
-            pred_ends = self.model.ops.xp.asarray(pred_ends)
-            start_accuracy = (starts == pred_starts).mean()
-            end_accuracy = (ends == pred_ends).mean()
-            start_scores.append(float(start_accuracy))
-            end_scores.append(float(end_accuracy))
-        out = {
-            "span_start_accuracy": mean(start_scores),
-            "span_end_accuracy": mean(end_scores)
-        }
-        return out
+            starts = xp.asarray(starts)
+            ends = xp.asarray(ends)
+            pred_starts = xp.asarray(pred_starts)
+            pred_ends = xp.asarray(pred_ends)
+            correct = (starts == pred_starts) * (ends == pred_ends)
+            accuracy = correct.mean()
+            scores.append(float(accuracy))
+        return {"span_accuracy": mean(scores)}