mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-11 16:52:21 +03:00
bugfix
This commit is contained in:
parent
ec941a128d
commit
43162029bc
|
@ -731,25 +731,23 @@ class SpanCategorizer(TrainablePipe):
|
||||||
# Filter samples according to threshold.
|
# Filter samples according to threshold.
|
||||||
threshold = self.cfg["threshold"]
|
threshold = self.cfg["threshold"]
|
||||||
if threshold is not None:
|
if threshold is not None:
|
||||||
print(argmax_scores >= threshold)
|
keeps = numpy.logical_and(keeps, (argmax_scores >= threshold).squeeze())
|
||||||
keeps = numpy.logical_and(keeps, argmax_scores >= threshold)
|
|
||||||
# Sort spans according to argmax probability
|
# Sort spans according to argmax probability
|
||||||
if not allow_overlap:
|
if not allow_overlap:
|
||||||
# Get the probabilities
|
# Get the probabilities
|
||||||
sort_idx = (argmax_scores * -1).argsort()
|
sort_idx = (argmax_scores.squeeze() * -1).argsort()
|
||||||
predicted = predicted[sort_idx]
|
predicted = predicted[sort_idx]
|
||||||
indices = indices[sort_idx]
|
indices = indices[sort_idx]
|
||||||
|
keeps = keeps[sort_idx]
|
||||||
# TODO assigns spans.attrs["scores"]
|
|
||||||
seen = Intervals()
|
seen = Intervals()
|
||||||
spans = SpanGroup(doc, name=self.key)
|
spans = SpanGroup(doc, name=self.key)
|
||||||
attrs_scores = []
|
attrs_scores = []
|
||||||
for i in range(indices.shape[0]):
|
for i in range(indices.shape[0]):
|
||||||
if not keeps[i]:
|
|
||||||
continue
|
|
||||||
label = predicted[i]
|
label = predicted[i]
|
||||||
start = indices[i, 0]
|
start = indices[i, 0]
|
||||||
end = indices[i, 1]
|
end = indices[i, 1]
|
||||||
|
if not keeps[i]:
|
||||||
|
continue
|
||||||
|
|
||||||
if not allow_overlap:
|
if not allow_overlap:
|
||||||
if (start, end) in seen:
|
if (start, end) in seen:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user