Mirror of https://github.com/explosion/spaCy.git
Commit: ef141ad399 ("span accuracy score")
Parent: a1d0219903
@@ -457,6 +457,7 @@ class SpanPredictor(TrainablePipe):
         for doc in docs:
             # TODO check shape here
             span_scores = self.model.predict([doc])
+            print(span_scores)
             if span_scores.size:
                 # the information about clustering has to come from the input docs
                 # first let's convert the scores to a list of span idxs
@@ -608,30 +609,35 @@ class SpanPredictor(TrainablePipe):
         self.model.initialize(X=X, Y=Y)
 
     def score(self, examples, **kwargs):
-        """Score a batch of examples."""
-        # TODO This is basically the same as the main coref component - factor out?
+        """
+        Evaluate on reconstructing the correct spans around
+        gold heads.
+        """
         scores = []
-        for metric in (b_cubed, muc, ceafe):
-            evaluator = Evaluator(metric)
-
-            for ex in examples:
-                # XXX this is the only different part
-                p_clusters = doc2clusters(ex.predicted, self.output_prefix)
-                g_clusters = doc2clusters(ex.reference, self.output_prefix)
-                cluster_info = get_cluster_info(p_clusters, g_clusters)
-
-                evaluator.update(cluster_info)
-
-            score = {
-                "coref_span_f": evaluator.get_f1(),
-                "coref_span_p": evaluator.get_precision(),
-                "coref_span_r": evaluator.get_recall(),
-            }
-            scores.append(score)
-
-        out = {}
-        for field in ("f", "p", "r"):
-            fname = f"coref_span_{field}"
-            out[fname] = mean([ss[fname] for ss in scores])
+        for eg in examples:
+            starts = []
+            ends = []
+            pred_starts = []
+            pred_ends = []
+            ref = eg.reference
+            pred = eg.predicted
+            for key, gold_sg in ref.spans.items():
+                if key.startswith(self.input_prefix):
+                    cluster_id = key.split('_')[-1]
+                    # FIXME THIS DOESN'T WORK BECAUSE pred.spans are empty?
+                    pred_sg = pred.spans[f"{self.output_prefix}_{cluster_id}"]
+                    for gold_mention, pred_mention in zip(gold_sg, pred_sg):
+                        starts.append(gold_mention.start)
+                        ends.append(gold_mention.end)
+                        pred_starts.append(pred_mention.start)
+                        pred_ends.append(pred_mention.end)
+
+            starts = self.model.ops.xp.asarray(starts)
+            ends = self.model.ops.xp.asarray(ends)
+            pred_starts = self.model.ops.xp.asarray(pred_starts)
+            pred_ends = self.model.ops.xp.asarray(pred_ends)
+            correct = ((starts == pred_starts) * (ends == pred_ends)).sum()
+            scores.append(correct)
+
+        out = {"span_accuracy": mean(scores)}
         return out
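The commit replaces the MUC/B³/CEAF cluster metrics in `SpanPredictor.score` with a simple span accuracy: gold and predicted span groups are paired via their `doc.spans` keys, a span counts as correct only when both its start and its end match the gold boundaries, and the metric is the mean of the per-example correct counts. A minimal standalone sketch of that computation, with plain numpy standing in for `self.model.ops.xp` (the `span_accuracy` helper name and the sample boundary pairs are illustrative, not part of the commit):

from statistics import mean
import numpy as np

def span_accuracy(gold_pairs, pred_pairs):
    # gold_pairs / pred_pairs: per-example lists of (start, end) tuples,
    # aligned one-to-one as in the zip() inside the new score()
    scores = []
    for gold, pred in zip(gold_pairs, pred_pairs):
        starts = np.asarray([s for s, _ in gold])
        ends = np.asarray([e for _, e in gold])
        pred_starts = np.asarray([s for s, _ in pred])
        pred_ends = np.asarray([e for _, e in pred])
        # a span counts as correct only if both boundaries match exactly
        correct = ((starts == pred_starts) * (ends == pred_ends)).sum()
        scores.append(int(correct))
    # as in the commit, this averages per-example counts of correct
    # spans, not fractions of spans that are correct
    return {"span_accuracy": mean(scores)}

# one example with three gold spans; two predictions match both boundaries
print(span_accuracy(
    [[(0, 2), (5, 7), (9, 12)]],
    [[(0, 2), (5, 8), (9, 12)]],
))  # {'span_accuracy': 2}

In the component itself the arrays come from `self.model.ops.xp`, which resolves to numpy on CPU and cupy on GPU, so the same element-wise comparison works on either backend.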