mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 12:12:20 +03:00
handle empty clusters
This commit is contained in:
parent
4fc40340f9
commit
e4b4b67ef6
|
@ -458,27 +458,29 @@ class SpanPredictor(TrainablePipe):
|
|||
out = []
|
||||
for doc in docs:
|
||||
# TODO check shape here
|
||||
span_scores = self.model.predict(doc)
|
||||
span_scores = span_scores[0]
|
||||
# the information about clustering has to come from the input docs
|
||||
# first let's convert the scores to a list of span idxs
|
||||
start_scores = span_scores[:, :, 0]
|
||||
end_scores = span_scores[:, :, 1]
|
||||
starts = start_scores.argmax(axis=1)
|
||||
ends = end_scores.argmax(axis=1)
|
||||
span_scores = self.model.predict([doc])
|
||||
if span_scores.size:
|
||||
# the information about clustering has to come from the input docs
|
||||
# first let's convert the scores to a list of span idxs
|
||||
start_scores = span_scores[:, :, 0]
|
||||
end_scores = span_scores[:, :, 1]
|
||||
starts = start_scores.argmax(axis=1)
|
||||
ends = end_scores.argmax(axis=1)
|
||||
|
||||
# TODO check start < end
|
||||
# TODO check start < end
|
||||
|
||||
# get the old clusters (shape will be preserved)
|
||||
clusters = doc2clusters(doc, self.input_prefix)
|
||||
cidx = 0
|
||||
out_clusters = []
|
||||
for cluster in clusters:
|
||||
ncluster = []
|
||||
for mention in cluster:
|
||||
ncluster.append( (starts[cidx], ends[cidx]) )
|
||||
cidx += 1
|
||||
out_clusters.append(ncluster)
|
||||
# get the old clusters (shape will be preserved)
|
||||
clusters = doc2clusters(doc, self.input_prefix)
|
||||
cidx = 0
|
||||
out_clusters = []
|
||||
for cluster in clusters:
|
||||
ncluster = []
|
||||
for mention in cluster:
|
||||
ncluster.append((starts[cidx], ends[cidx]))
|
||||
cidx += 1
|
||||
out_clusters.append(ncluster)
|
||||
else:
|
||||
out_clusters = []
|
||||
out.append(out_clusters)
|
||||
return out
|
||||
|
||||
|
@ -628,7 +630,6 @@ class SpanPredictor(TrainablePipe):
|
|||
# XXX this is the only different part
|
||||
p_clusters = doc2clusters(ex.predicted, self.output_prefix)
|
||||
g_clusters = doc2clusters(ex.reference, self.output_prefix)
|
||||
|
||||
cluster_info = get_cluster_info(p_clusters, g_clusters)
|
||||
|
||||
evaluator.update(cluster_info)
|
||||
|
|
Loading…
Reference in New Issue
Block a user