From e4b4b67ef6f627f7cd9cd313ab9274779c16c971 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A1d=C3=A1r=20=C3=81kos?= Date: Mon, 28 Mar 2022 11:29:00 +0200 Subject: [PATCH] handle empty clusters --- spacy/pipeline/coref.py | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py index 99bb611ff..5a4fa1ab9 100644 --- a/spacy/pipeline/coref.py +++ b/spacy/pipeline/coref.py @@ -458,27 +458,29 @@ class SpanPredictor(TrainablePipe): out = [] for doc in docs: # TODO check shape here - span_scores = self.model.predict(doc) - span_scores = span_scores[0] - # the information about clustering has to come from the input docs - # first let's convert the scores to a list of span idxs - start_scores = span_scores[:, :, 0] - end_scores = span_scores[:, :, 1] - starts = start_scores.argmax(axis=1) - ends = end_scores.argmax(axis=1) + span_scores = self.model.predict([doc]) + if span_scores.size: + # the information about clustering has to come from the input docs + # first let's convert the scores to a list of span idxs + start_scores = span_scores[:, :, 0] + end_scores = span_scores[:, :, 1] + starts = start_scores.argmax(axis=1) + ends = end_scores.argmax(axis=1) - # TODO check start < end + # TODO check start < end - # get the old clusters (shape will be preserved) - clusters = doc2clusters(doc, self.input_prefix) - cidx = 0 - out_clusters = [] - for cluster in clusters: - ncluster = [] - for mention in cluster: - ncluster.append( (starts[cidx], ends[cidx]) ) - cidx += 1 - out_clusters.append(ncluster) + # get the old clusters (shape will be preserved) + clusters = doc2clusters(doc, self.input_prefix) + cidx = 0 + out_clusters = [] + for cluster in clusters: + ncluster = [] + for mention in cluster: + ncluster.append((starts[cidx], ends[cidx])) + cidx += 1 + out_clusters.append(ncluster) + else: + out_clusters = [] out.append(out_clusters) return out @@ -628,7 +630,6 @@ class SpanPredictor(TrainablePipe): # XXX this is the only different part p_clusters = doc2clusters(ex.predicted, self.output_prefix) g_clusters = doc2clusters(ex.reference, self.output_prefix) - cluster_info = get_cluster_info(p_clusters, g_clusters) evaluator.update(cluster_info)