This commit is contained in:
kadarakos 2023-02-08 19:43:51 +00:00
parent ec941a128d
commit 43162029bc

View File

@@ -731,25 +731,23 @@ class SpanCategorizer(TrainablePipe):
# Filter samples according to threshold.
threshold = self.cfg["threshold"]
if threshold is not None:
print(argmax_scores >= threshold)
keeps = numpy.logical_and(keeps, argmax_scores >= threshold)
keeps = numpy.logical_and(keeps, (argmax_scores >= threshold).squeeze())
# Sort spans according to argmax probability
if not allow_overlap:
# Get the probabilities
sort_idx = (argmax_scores * -1).argsort()
sort_idx = (argmax_scores.squeeze() * -1).argsort()
predicted = predicted[sort_idx]
indices = indices[sort_idx]
# TODO assigns spans.attrs["scores"]
keeps = keeps[sort_idx]
seen = Intervals()
spans = SpanGroup(doc, name=self.key)
attrs_scores = []
for i in range(indices.shape[0]):
if not keeps[i]:
continue
label = predicted[i]
start = indices[i, 0]
end = indices[i, 1]
if not keeps[i]:
continue
if not allow_overlap:
if (start, end) in seen: