revert to not interleaving (realized it's faster)

This commit is contained in:
kadarakos 2023-05-03 13:32:41 +00:00
parent a5b9e63664
commit 82f6a813f0

View File

@ -229,21 +229,23 @@ class SpanFinder(TrainablePipe):
offset = 0 offset = 0
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
doc.spans[self.predicted_key] = [] doc.spans[self.predicted_key] = []
doc_scores = scores[offset : offset + len(doc)] starts = []
for j in range(len(doc)): ends = []
start_token_score = doc_scores[j] doc_scores = scores[offset:offset + len(doc)]
# If token is a START then start scanning following tokens
if start_token_score[0] >= self.threshold: for token, token_score in zip(doc, doc_scores):
for k in range(j, len(doc)): if token_score[0] >= self.threshold:
end_token_score = doc_scores[k] starts.append(token.i)
# If token is an END check whether the length constraint is met if token_score[1] >= self.threshold:
if end_token_score[1] >= self.threshold: ends.append(token.i)
span_length = k + 1 - j
if span_length > self.max_length: for start in starts:
break for end in ends:
elif self.min_length <= span_length: span_length = end + 1 - start
span = doc[j : k + 1] if span_length > self.max_length:
doc.spans[self.predicted_key].append(span) break
elif self.min_length <= span_length:
doc.spans[self.predicted_key].append(doc[start : end + 1])
def update( def update(
self, self,