mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 20:22:25 +03:00
handle empty clusters
This commit is contained in:
parent
4fc40340f9
commit
e4b4b67ef6
|
@ -458,27 +458,29 @@ class SpanPredictor(TrainablePipe):
|
||||||
out = []
|
out = []
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
# TODO check shape here
|
# TODO check shape here
|
||||||
span_scores = self.model.predict(doc)
|
span_scores = self.model.predict([doc])
|
||||||
span_scores = span_scores[0]
|
if span_scores.size:
|
||||||
# the information about clustering has to come from the input docs
|
# the information about clustering has to come from the input docs
|
||||||
# first let's convert the scores to a list of span idxs
|
# first let's convert the scores to a list of span idxs
|
||||||
start_scores = span_scores[:, :, 0]
|
start_scores = span_scores[:, :, 0]
|
||||||
end_scores = span_scores[:, :, 1]
|
end_scores = span_scores[:, :, 1]
|
||||||
starts = start_scores.argmax(axis=1)
|
starts = start_scores.argmax(axis=1)
|
||||||
ends = end_scores.argmax(axis=1)
|
ends = end_scores.argmax(axis=1)
|
||||||
|
|
||||||
# TODO check start < end
|
# TODO check start < end
|
||||||
|
|
||||||
# get the old clusters (shape will be preserved)
|
# get the old clusters (shape will be preserved)
|
||||||
clusters = doc2clusters(doc, self.input_prefix)
|
clusters = doc2clusters(doc, self.input_prefix)
|
||||||
cidx = 0
|
cidx = 0
|
||||||
out_clusters = []
|
out_clusters = []
|
||||||
for cluster in clusters:
|
for cluster in clusters:
|
||||||
ncluster = []
|
ncluster = []
|
||||||
for mention in cluster:
|
for mention in cluster:
|
||||||
ncluster.append( (starts[cidx], ends[cidx]) )
|
ncluster.append((starts[cidx], ends[cidx]))
|
||||||
cidx += 1
|
cidx += 1
|
||||||
out_clusters.append(ncluster)
|
out_clusters.append(ncluster)
|
||||||
|
else:
|
||||||
|
out_clusters = []
|
||||||
out.append(out_clusters)
|
out.append(out_clusters)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -628,7 +630,6 @@ class SpanPredictor(TrainablePipe):
|
||||||
# XXX this is the only different part
|
# XXX this is the only different part
|
||||||
p_clusters = doc2clusters(ex.predicted, self.output_prefix)
|
p_clusters = doc2clusters(ex.predicted, self.output_prefix)
|
||||||
g_clusters = doc2clusters(ex.reference, self.output_prefix)
|
g_clusters = doc2clusters(ex.reference, self.output_prefix)
|
||||||
|
|
||||||
cluster_info = get_cluster_info(p_clusters, g_clusters)
|
cluster_info = get_cluster_info(p_clusters, g_clusters)
|
||||||
|
|
||||||
evaluator.update(cluster_info)
|
evaluator.update(cluster_info)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user