mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 12:12:20 +03:00
report start- and end-accuracies separately
This commit is contained in:
parent
2a1ad4c5d2
commit
7a239f2ec7
|
@ -510,13 +510,11 @@ class SpanPredictor(TrainablePipe):
|
||||||
|
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
for eg in examples:
|
for eg in examples:
|
||||||
# For update we use the gold coref_head_clusters
|
|
||||||
# in the reference.
|
|
||||||
span_scores, backprop = self.model.begin_update([eg.predicted])
|
span_scores, backprop = self.model.begin_update([eg.predicted])
|
||||||
loss, d_scores = self.get_loss([eg], span_scores)
|
loss, d_scores = self.get_loss([eg], span_scores)
|
||||||
total_loss += loss
|
total_loss += loss
|
||||||
# TODO check shape here
|
# TODO check shape here
|
||||||
backprop(d_scores)
|
backprop((d_scores))
|
||||||
|
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
self.finish_update(sgd)
|
self.finish_update(sgd)
|
||||||
|
@ -612,7 +610,8 @@ class SpanPredictor(TrainablePipe):
|
||||||
Evaluate on reconstructing the correct spans around
|
Evaluate on reconstructing the correct spans around
|
||||||
gold heads.
|
gold heads.
|
||||||
"""
|
"""
|
||||||
scores = []
|
start_scores = []
|
||||||
|
end_scores = []
|
||||||
for eg in examples:
|
for eg in examples:
|
||||||
starts = []
|
starts = []
|
||||||
ends = []
|
ends = []
|
||||||
|
@ -622,7 +621,6 @@ class SpanPredictor(TrainablePipe):
|
||||||
pred = eg.predicted
|
pred = eg.predicted
|
||||||
for key, gold_sg in ref.spans.items():
|
for key, gold_sg in ref.spans.items():
|
||||||
if key.startswith(self.output_prefix):
|
if key.startswith(self.output_prefix):
|
||||||
cluster_id = key.split('_')[-1]
|
|
||||||
pred_sg = pred.spans[key]
|
pred_sg = pred.spans[key]
|
||||||
for gold_mention, pred_mention in zip(gold_sg, pred_sg):
|
for gold_mention, pred_mention in zip(gold_sg, pred_sg):
|
||||||
starts.append(gold_mention.start)
|
starts.append(gold_mention.start)
|
||||||
|
@ -634,8 +632,12 @@ class SpanPredictor(TrainablePipe):
|
||||||
ends = self.model.ops.xp.asarray(ends)
|
ends = self.model.ops.xp.asarray(ends)
|
||||||
pred_starts = self.model.ops.xp.asarray(pred_starts)
|
pred_starts = self.model.ops.xp.asarray(pred_starts)
|
||||||
pred_ends = self.model.ops.xp.asarray(pred_ends)
|
pred_ends = self.model.ops.xp.asarray(pred_ends)
|
||||||
correct = ((starts == pred_starts) * (ends == pred_ends)).sum()
|
start_accuracy = (starts == pred_starts).mean()
|
||||||
scores.append(correct)
|
end_accuracy = (ends == pred_ends).mean()
|
||||||
|
start_scores.append(float(start_accuracy))
|
||||||
out = {"span_accuracy": mean(scores)}
|
end_scores.append(float(end_accuracy))
|
||||||
|
out = {
|
||||||
|
"span_start_accuracy": mean(start_scores),
|
||||||
|
"span_end_accuracy": mean(end_scores)
|
||||||
|
}
|
||||||
return out
|
return out
|
||||||
|
|
Loading…
Reference in New Issue
Block a user