mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 12:12:20 +03:00
handle empty clusters
This commit is contained in:
parent
4fc40340f9
commit
e4b4b67ef6
|
@ -458,8 +458,8 @@ class SpanPredictor(TrainablePipe):
|
|||
out = []
|
||||
for doc in docs:
|
||||
# TODO check shape here
|
||||
span_scores = self.model.predict(doc)
|
||||
span_scores = span_scores[0]
|
||||
span_scores = self.model.predict([doc])
|
||||
if span_scores.size:
|
||||
# the information about clustering has to come from the input docs
|
||||
# first let's convert the scores to a list of span idxs
|
||||
start_scores = span_scores[:, :, 0]
|
||||
|
@ -479,6 +479,8 @@ class SpanPredictor(TrainablePipe):
|
|||
ncluster.append((starts[cidx], ends[cidx]))
|
||||
cidx += 1
|
||||
out_clusters.append(ncluster)
|
||||
else:
|
||||
out_clusters = []
|
||||
out.append(out_clusters)
|
||||
return out
|
||||
|
||||
|
@ -628,7 +630,6 @@ class SpanPredictor(TrainablePipe):
|
|||
# XXX this is the only different part
|
||||
p_clusters = doc2clusters(ex.predicted, self.output_prefix)
|
||||
g_clusters = doc2clusters(ex.reference, self.output_prefix)
|
||||
|
||||
cluster_info = get_cluster_info(p_clusters, g_clusters)
|
||||
|
||||
evaluator.update(cluster_info)
|
||||
|
|
Loading…
Reference in New Issue
Block a user