mirror of https://github.com/explosion/spaCy.git
span accuracy score

parent a1d0219903
commit ef141ad399
@@ -457,6 +457,7 @@ class SpanPredictor(TrainablePipe):
         for doc in docs:
             # TODO check shape here
             span_scores = self.model.predict([doc])
+            print(span_scores)
             if span_scores.size:
                 # the information about clustering has to come from the input docs
                 # first let's convert the scores to a list of span idxs
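The comments above leave the score-to-span conversion as a TODO. A minimal sketch of one way to do it, assuming span_scores is an array of shape (n_heads, n_words, 2) with start scores in the first channel and end scores in the second (the shape and layout are assumptions, not confirmed by this diff):

import numpy as np

def scores_to_span_idxs(span_scores: np.ndarray) -> list:
    # Assumed layout: span_scores[h, w, 0] scores token w as the start of
    # head h's span, and span_scores[h, w, 1] scores it as the end.
    starts = span_scores[:, :, 0].argmax(axis=1)    # best start per head
    ends = span_scores[:, :, 1].argmax(axis=1) + 1  # best end, exclusive
    return list(zip(starts.tolist(), ends.tolist()))

Decoding start and end independently is the simplest choice; a joint argmax over valid start <= end pairs would be stricter but costlier.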
@@ -608,30 +609,35 @@ class SpanPredictor(TrainablePipe):
         self.model.initialize(X=X, Y=Y)
 
     def score(self, examples, **kwargs):
-        """Score a batch of examples."""
-        # TODO This is basically the same as the main coref component - factor out?
-
+        """
+        Evaluate on reconstructing the correct spans around
+        gold heads.
+        """
         scores = []
-        for metric in (b_cubed, muc, ceafe):
-            evaluator = Evaluator(metric)
+        for eg in examples:
+            starts = []
+            ends = []
+            pred_starts = []
+            pred_ends = []
+            ref = eg.reference
+            pred = eg.predicted
+            for key, gold_sg in ref.spans.items():
+                if key.startswith(self.input_prefix):
+                    cluster_id = key.split('_')[-1]
+                    # FIXME THIS DOESN'T WORK BECAUSE pred.spans are empty?
+                    pred_sg = pred.spans[f"{self.output_prefix}_{cluster_id}"]
+                    for gold_mention, pred_mention in zip(gold_sg, pred_sg):
+                        starts.append(gold_mention.start)
+                        ends.append(gold_mention.end)
+                        pred_starts.append(pred_mention.start)
+                        pred_ends.append(pred_mention.end)
 
-            for ex in examples:
-                # XXX this is the only different part
-                p_clusters = doc2clusters(ex.predicted, self.output_prefix)
-                g_clusters = doc2clusters(ex.reference, self.output_prefix)
-                cluster_info = get_cluster_info(p_clusters, g_clusters)
+            starts = self.model.ops.xp.asarray(starts)
+            ends = self.model.ops.xp.asarray(ends)
+            pred_starts = self.model.ops.xp.asarray(pred_starts)
+            pred_ends = self.model.ops.xp.asarray(pred_ends)
+            correct = ((starts == pred_starts) * (ends == pred_ends)).sum()
+            scores.append(correct)
 
-                evaluator.update(cluster_info)
-
-            score = {
-                "coref_span_f": evaluator.get_f1(),
-                "coref_span_p": evaluator.get_precision(),
-                "coref_span_r": evaluator.get_recall(),
-            }
-            scores.append(score)
-
-        out = {}
-        for field in ("f", "p", "r"):
-            fname = f"coref_span_{field}"
-            out[fname] = mean([ss[fname] for ss in scores])
+        out = {"span_accuracy": mean(scores)}
         return out
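Stripped of the pipeline plumbing, the metric this commit introduces counts a predicted span as correct only when both boundaries match the gold span exactly. A standalone sketch (a hypothetical helper, not the component's code; note it divides by the number of gold spans, whereas the diff appends raw match counts to scores before averaging):

def span_accuracy(gold_spans, pred_spans):
    # gold_spans / pred_spans: aligned lists of (start, end) token offsets.
    # A prediction counts only if start AND end both match, mirroring
    # ((starts == pred_starts) * (ends == pred_ends)).sum() above.
    if not gold_spans:
        return 0.0
    correct = sum(g == p for g, p in zip(gold_spans, pred_spans))
    return correct / len(gold_spans)

# One of two spans reconstructed exactly -> 0.5
assert span_accuracy([(0, 2), (5, 9)], [(0, 2), (5, 8)]) == 0.5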