handle empty clusters

This commit is contained in:
Kádár Ákos 2022-03-28 11:29:00 +02:00
parent 4fc40340f9
commit e4b4b67ef6

View File

@ -458,27 +458,29 @@ class SpanPredictor(TrainablePipe):
out = []
for doc in docs:
# TODO check shape here
span_scores = self.model.predict(doc)
span_scores = span_scores[0]
# the information about clustering has to come from the input docs
# first let's convert the scores to a list of span idxs
start_scores = span_scores[:, :, 0]
end_scores = span_scores[:, :, 1]
starts = start_scores.argmax(axis=1)
ends = end_scores.argmax(axis=1)
span_scores = self.model.predict([doc])
if span_scores.size:
# the information about clustering has to come from the input docs
# first let's convert the scores to a list of span idxs
start_scores = span_scores[:, :, 0]
end_scores = span_scores[:, :, 1]
starts = start_scores.argmax(axis=1)
ends = end_scores.argmax(axis=1)
# TODO check start < end
# TODO check start < end
# get the old clusters (shape will be preserved)
clusters = doc2clusters(doc, self.input_prefix)
cidx = 0
out_clusters = []
for cluster in clusters:
ncluster = []
for mention in cluster:
ncluster.append( (starts[cidx], ends[cidx]) )
cidx += 1
out_clusters.append(ncluster)
# get the old clusters (shape will be preserved)
clusters = doc2clusters(doc, self.input_prefix)
cidx = 0
out_clusters = []
for cluster in clusters:
ncluster = []
for mention in cluster:
ncluster.append((starts[cidx], ends[cidx]))
cidx += 1
out_clusters.append(ncluster)
else:
out_clusters = []
out.append(out_clusters)
return out
@ -628,7 +630,6 @@ class SpanPredictor(TrainablePipe):
# XXX this is the only different part
p_clusters = doc2clusters(ex.predicted, self.output_prefix)
g_clusters = doc2clusters(ex.reference, self.output_prefix)
cluster_info = get_cluster_info(p_clusters, g_clusters)
evaluator.update(cluster_info)