interleave thresholding with span creation

This commit is contained in:
kadarakos 2023-05-03 13:26:31 +00:00
parent 6b2e8363fc
commit db361db874

View File

@ -229,23 +229,21 @@ class SpanFinder(TrainablePipe):
offset = 0 offset = 0
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
doc.spans[self.predicted_key] = [] doc.spans[self.predicted_key] = []
starts = []
ends = []
doc_scores = scores[offset:offset + len(doc)] doc_scores = scores[offset:offset + len(doc)]
for j in range(len(doc)):
for token, token_score in zip(doc, doc_scores): start_token_score = doc_scores[j]
if token_score[0] >= self.threshold: # If token is a START then start scanning following tokens
starts.append(token.i) if start_token_score[0] >= self.threshold:
if token_score[1] >= self.threshold: for k in range(j, len(doc)):
ends.append(token.i) end_token_score = doc_scores[k]
# If token is an END check whether the lenght contraint is met
for start in starts: if end_token_score[1] >= self.threshold:
for end in ends: span_length = (k + 1 - j)
span_length = end + 1 - start if span_length > self.max_length:
if span_length > self.max_length: break
break elif self.min_length <= span_length:
elif self.min_length <= span_length: span = doc[j:k + 1]
doc.spans[self.predicted_key].append(doc[start : end + 1]) doc.spans[self.predicted_key].append(span)
def update( def update(
self, self,