mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-11 00:32:40 +03:00
bugfix
This commit is contained in:
parent
ec941a128d
commit
43162029bc
|
@ -731,25 +731,23 @@ class SpanCategorizer(TrainablePipe):
|
|||
# Filter samples according to threshold.
|
||||
threshold = self.cfg["threshold"]
|
||||
if threshold is not None:
|
||||
print(argmax_scores >= threshold)
|
||||
keeps = numpy.logical_and(keeps, argmax_scores >= threshold)
|
||||
keeps = numpy.logical_and(keeps, (argmax_scores >= threshold).squeeze())
|
||||
# Sort spans according to argmax probability
|
||||
if not allow_overlap:
|
||||
# Get the probabilities
|
||||
sort_idx = (argmax_scores * -1).argsort()
|
||||
sort_idx = (argmax_scores.squeeze() * -1).argsort()
|
||||
predicted = predicted[sort_idx]
|
||||
indices = indices[sort_idx]
|
||||
|
||||
# TODO assigns spans.attrs["scores"]
|
||||
keeps = keeps[sort_idx]
|
||||
seen = Intervals()
|
||||
spans = SpanGroup(doc, name=self.key)
|
||||
attrs_scores = []
|
||||
for i in range(indices.shape[0]):
|
||||
if not keeps[i]:
|
||||
continue
|
||||
label = predicted[i]
|
||||
start = indices[i, 0]
|
||||
end = indices[i, 1]
|
||||
if not keeps[i]:
|
||||
continue
|
||||
|
||||
if not allow_overlap:
|
||||
if (start, end) in seen:
|
||||
|
|
Loading…
Reference in New Issue
Block a user