From db361db874ff31355afc2b189529328b1ee02006 Mon Sep 17 00:00:00 2001 From: kadarakos Date: Wed, 3 May 2023 13:26:31 +0000 Subject: [PATCH] interleave thresholding with span creation --- spacy/pipeline/span_finder.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/spacy/pipeline/span_finder.py b/spacy/pipeline/span_finder.py index bdfa055f3..9f2ff1a78 100644 --- a/spacy/pipeline/span_finder.py +++ b/spacy/pipeline/span_finder.py @@ -229,23 +229,21 @@ class SpanFinder(TrainablePipe): offset = 0 for i, doc in enumerate(docs): doc.spans[self.predicted_key] = [] - starts = [] - ends = [] doc_scores = scores[offset:offset + len(doc)] - - for token, token_score in zip(doc, doc_scores): - if token_score[0] >= self.threshold: - starts.append(token.i) - if token_score[1] >= self.threshold: - ends.append(token.i) - - for start in starts: - for end in ends: - span_length = end + 1 - start - if span_length > self.max_length: - break - elif self.min_length <= span_length: - doc.spans[self.predicted_key].append(doc[start : end + 1]) + for j in range(len(doc)): + start_token_score = doc_scores[j] + # If token is a START then start scanning following tokens + if start_token_score[0] >= self.threshold: + for k in range(j, len(doc)): + end_token_score = doc_scores[k] + # If token is an END check whether the lenght contraint is met + if end_token_score[1] >= self.threshold: + span_length = (k + 1 - j) + if span_length > self.max_length: + break + elif self.min_length <= span_length: + span = doc[j:k + 1] + doc.spans[self.predicted_key].append(span) def update( self,