handle empty clusters

This commit is contained in:
Kádár Ákos 2022-03-28 11:29:00 +02:00
parent 4fc40340f9
commit e4b4b67ef6

View File

@@ -458,8 +458,8 @@ class SpanPredictor(TrainablePipe):
out = [] out = []
for doc in docs: for doc in docs:
# TODO check shape here # TODO check shape here
span_scores = self.model.predict(doc) span_scores = self.model.predict([doc])
span_scores = span_scores[0] if span_scores.size:
# the information about clustering has to come from the input docs # the information about clustering has to come from the input docs
# first let's convert the scores to a list of span idxs # first let's convert the scores to a list of span idxs
start_scores = span_scores[:, :, 0] start_scores = span_scores[:, :, 0]
@@ -476,9 +476,11 @@ class SpanPredictor(TrainablePipe):
for cluster in clusters: for cluster in clusters:
ncluster = [] ncluster = []
for mention in cluster: for mention in cluster:
ncluster.append( (starts[cidx], ends[cidx]) ) ncluster.append((starts[cidx], ends[cidx]))
cidx += 1 cidx += 1
out_clusters.append(ncluster) out_clusters.append(ncluster)
else:
out_clusters = []
out.append(out_clusters) out.append(out_clusters)
return out return out
@@ -628,7 +630,6 @@ class SpanPredictor(TrainablePipe):
# XXX this is the only different part # XXX this is the only different part
p_clusters = doc2clusters(ex.predicted, self.output_prefix) p_clusters = doc2clusters(ex.predicted, self.output_prefix)
g_clusters = doc2clusters(ex.reference, self.output_prefix) g_clusters = doc2clusters(ex.reference, self.output_prefix)
cluster_info = get_cluster_info(p_clusters, g_clusters) cluster_info = get_cluster_info(p_clusters, g_clusters)
evaluator.update(cluster_info) evaluator.update(cluster_info)