mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 20:52:23 +03:00
fixing scorer
This commit is contained in:
parent
7a239f2ec7
commit
6aedd98d02
|
@ -511,6 +511,9 @@ class SpanPredictor(TrainablePipe):
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
for eg in examples:
|
for eg in examples:
|
||||||
span_scores, backprop = self.model.begin_update([eg.predicted])
|
span_scores, backprop = self.model.begin_update([eg.predicted])
|
||||||
|
# FIXME, this only happens once in the first 1000 docs of OntoNotes
|
||||||
|
# and I'm not sure yet why.
|
||||||
|
if span_scores.size:
|
||||||
loss, d_scores = self.get_loss([eg], span_scores)
|
loss, d_scores = self.get_loss([eg], span_scores)
|
||||||
total_loss += loss
|
total_loss += loss
|
||||||
# TODO check shape here
|
# TODO check shape here
|
||||||
|
@ -557,7 +560,6 @@ class SpanPredictor(TrainablePipe):
|
||||||
assert len(examples) == 1, "Only fake batching is supported."
|
assert len(examples) == 1, "Only fake batching is supported."
|
||||||
# starts and ends are gold starts and ends (Ints1d)
|
# starts and ends are gold starts and ends (Ints1d)
|
||||||
# span_scores is a Floats3d. What are the axes? mention x token x start/end
|
# span_scores is a Floats3d. What are the axes? mention x token x start/end
|
||||||
|
|
||||||
for eg in examples:
|
for eg in examples:
|
||||||
starts = []
|
starts = []
|
||||||
ends = []
|
ends = []
|
||||||
|
@ -610,8 +612,8 @@ class SpanPredictor(TrainablePipe):
|
||||||
Evaluate on reconstructing the correct spans around
|
Evaluate on reconstructing the correct spans around
|
||||||
gold heads.
|
gold heads.
|
||||||
"""
|
"""
|
||||||
start_scores = []
|
scores = []
|
||||||
end_scores = []
|
xp = self.model.ops.xp
|
||||||
for eg in examples:
|
for eg in examples:
|
||||||
starts = []
|
starts = []
|
||||||
ends = []
|
ends = []
|
||||||
|
@ -628,16 +630,11 @@ class SpanPredictor(TrainablePipe):
|
||||||
pred_starts.append(pred_mention.start)
|
pred_starts.append(pred_mention.start)
|
||||||
pred_ends.append(pred_mention.end)
|
pred_ends.append(pred_mention.end)
|
||||||
|
|
||||||
starts = self.model.ops.xp.asarray(starts)
|
starts = xp.asarray(starts)
|
||||||
ends = self.model.ops.xp.asarray(ends)
|
ends = xp.asarray(ends)
|
||||||
pred_starts = self.model.ops.xp.asarray(pred_starts)
|
pred_starts = xp.asarray(pred_starts)
|
||||||
pred_ends = self.model.ops.xp.asarray(pred_ends)
|
pred_ends = xp.asarray(pred_ends)
|
||||||
start_accuracy = (starts == pred_starts).mean()
|
correct = (starts == pred_starts) * (ends == pred_ends)
|
||||||
end_accuracy = (ends == pred_ends).mean()
|
accuracy = correct.mean()
|
||||||
start_scores.append(float(start_accuracy))
|
scores.append(float(accuracy))
|
||||||
end_scores.append(float(end_accuracy))
|
return {"span_accuracy": mean(scores)}
|
||||||
out = {
|
|
||||||
"span_start_accuracy": mean(start_scores),
|
|
||||||
"span_end_accuracy": mean(end_scores)
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user