diff --git a/spacy/pipeline/span_finder.py b/spacy/pipeline/span_finder.py index f06ea731a..bdfa055f3 100644 --- a/spacy/pipeline/span_finder.py +++ b/spacy/pipeline/span_finder.py @@ -229,21 +229,23 @@ class SpanFinder(TrainablePipe): offset = 0 for i, doc in enumerate(docs): doc.spans[self.predicted_key] = [] - doc_scores = scores[offset : offset + len(doc)] - for j in range(len(doc)): - start_token_score = doc_scores[j] - # If token is a START then start scanning following tokens - if start_token_score[0] >= self.threshold: - for k in range(j, len(doc)): - end_token_score = doc_scores[k] - # If token is an END check whether the lenght contraint is met - if end_token_score[1] >= self.threshold: - span_length = k + 1 - j - if span_length > self.max_length: - break - elif self.min_length <= span_length: - span = doc[j : k + 1] - doc.spans[self.predicted_key].append(span) + starts = [] + ends = [] + doc_scores = scores[offset:offset + len(doc)] + + for token, token_score in zip(doc, doc_scores): + if token_score[0] >= self.threshold: + starts.append(token.i) + if token_score[1] >= self.threshold: + ends.append(token.i) + + for start in starts: + for end in ends: + span_length = end + 1 - start + if span_length > self.max_length: + break + elif self.min_length <= span_length: + doc.spans[self.predicted_key].append(doc[start : end + 1]) def update( self,