mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-29 06:57:49 +03:00 
			
		
		
		
	span accuracy score
This commit is contained in:
		
							parent
							
								
									a1d0219903
								
							
						
					
					
						commit
						ef141ad399
					
				|  | @ -457,6 +457,7 @@ class SpanPredictor(TrainablePipe): | |||
|         for doc in docs: | ||||
|             # TODO check shape here | ||||
|             span_scores = self.model.predict([doc]) | ||||
|             print(span_scores) | ||||
|             if span_scores.size: | ||||
|                 # the information about clustering has to come from the input docs | ||||
|                 # first let's convert the scores to a list of span idxs | ||||
|  | @ -608,30 +609,35 @@ class SpanPredictor(TrainablePipe): | |||
|         self.model.initialize(X=X, Y=Y) | ||||
| 
 | ||||
|     def score(self, examples, **kwargs): | ||||
|         """Score a batch of examples.""" | ||||
|         # TODO This is basically the same as the main coref component - factor out? | ||||
| 
 | ||||
|         """ | ||||
|         Evaluate on reconstructing the correct spans around | ||||
|         gold heads. | ||||
|         """ | ||||
|         scores = [] | ||||
|         for metric in (b_cubed, muc, ceafe): | ||||
|             evaluator = Evaluator(metric) | ||||
|         for eg in examples: | ||||
|             starts = [] | ||||
|             ends = [] | ||||
|             pred_starts = [] | ||||
|             pred_ends = [] | ||||
|             ref = eg.reference | ||||
|             pred = eg.predicted | ||||
|             for key, gold_sg in ref.spans.items(): | ||||
|                 if key.startswith(self.input_prefix): | ||||
|                     cluster_id = key.split('_')[-1] | ||||
|                     # FIXME THIS DOESN'T WORK BECAUSE pred.spans are empty? | ||||
|                     pred_sg = pred.spans[f"{self.output_prefix}_{cluster_id}"] | ||||
|                     for gold_mention, pred_mention in zip(gold_sg, pred_sg): | ||||
|                         starts.append(gold_mention.start) | ||||
|                         ends.append(gold_mention.end) | ||||
|                         pred_starts.append(pred_mention.start) | ||||
|                         pred_ends.append(pred_mention.end) | ||||
| 
 | ||||
|             for ex in examples: | ||||
|                 # XXX this is the only different part | ||||
|                 p_clusters = doc2clusters(ex.predicted, self.output_prefix) | ||||
|                 g_clusters = doc2clusters(ex.reference, self.output_prefix) | ||||
|                 cluster_info = get_cluster_info(p_clusters, g_clusters) | ||||
|             starts = self.model.ops.xp.asarray(starts) | ||||
|             ends = self.model.ops.xp.asarray(ends) | ||||
|             pred_starts = self.model.ops.xp.asarray(pred_starts) | ||||
|             pred_ends = self.model.ops.xp.asarray(pred_ends) | ||||
|             correct = ((starts == pred_starts) * (ends == pred_ends)).sum() | ||||
|             scores.append(correct) | ||||
| 
 | ||||
|                 evaluator.update(cluster_info) | ||||
| 
 | ||||
|             score = { | ||||
|                 "coref_span_f": evaluator.get_f1(), | ||||
|                 "coref_span_p": evaluator.get_precision(), | ||||
|                 "coref_span_r": evaluator.get_recall(), | ||||
|             } | ||||
|             scores.append(score) | ||||
| 
 | ||||
|         out = {} | ||||
|         for field in ("f", "p", "r"): | ||||
|             fname = f"coref_span_{field}" | ||||
|             out[fname] = mean([ss[fname] for ss in scores]) | ||||
|         out = {"span_accuracy": mean(scores)} | ||||
|         return out | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user