diff --git a/spacy/pipeline/span_finder.py b/spacy/pipeline/span_finder.py index d767c6e18..3e44b8d01 100644 --- a/spacy/pipeline/span_finder.py +++ b/spacy/pipeline/span_finder.py @@ -117,20 +117,6 @@ def span_finder_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: return scores -class _MaxInt(int): - def __le__(self, other): - return False - - def __lt__(self, other): - return False - - def __ge__(self, other): - return True - - def __gt__(self, other): - return True - - def _char_indices(span: Span) -> Tuple[int, int]: start = span[0].idx end = span[-1].idx + len(span[-1]) @@ -163,25 +149,24 @@ class SpanFinder(TrainablePipe): initialization and training, the component will look for spans on the reference document under the same key. max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length. - min_length (Optional[int]): Min length of the produced spans, defaults to None meaining shortest span is length 1. + min_length (Optional[int]): Min length of the produced spans, defaults to None meaning shortest span is length 1. """ self.vocab = nlp.vocab - self.threshold = threshold - if max_length is None: - max_length = _MaxInt() - if min_length is None: - min_length = 1 - if max_length < 1 or min_length < 1: + if (max_length is not None and max_length < 1) or ( + min_length is not None and min_length < 1 + ): raise ValueError( Errors.E1053.format(min_length=min_length, max_length=max_length) ) - self.min_length = min_length - self.max_length = max_length - self.spans_key = spans_key self.model = model self.name = name self.scorer = scorer - self.cfg = {"spans_key": spans_key} + self.cfg: Dict[str, Any] = { + "min_length": min_length, + "max_length": max_length, + "threshold": threshold, + "spans_key": spans_key, + } def predict(self, docs: Iterable[Doc]): """Apply the pipeline's model to a batch of docs, without modifying them. @@ -198,24 +183,31 @@ class SpanFinder(TrainablePipe): """ offset = 0 for i, doc in enumerate(docs): - doc.spans[self.spans_key] = [] + doc.spans[self.cfg["spans_key"]] = [] starts = [] ends = [] doc_scores = scores[offset : offset + len(doc)] for token, token_score in zip(doc, doc_scores): - if token_score[0] >= self.threshold: + if token_score[0] >= self.cfg["threshold"]: starts.append(token.i) - if token_score[1] >= self.threshold: + if token_score[1] >= self.cfg["threshold"]: ends.append(token.i) + print(self.cfg) for start in starts: for end in ends: span_length = end + 1 - start - if span_length > self.max_length: - break - elif self.min_length <= span_length: - doc.spans[self.spans_key].append(doc[start : end + 1]) + if span_length < 1: + continue + if ( + self.cfg["min_length"] is None + or self.cfg["min_length"] <= span_length + ) and ( + self.cfg["max_length"] is None + or span_length <= self.cfg["max_length"] + ): + doc.spans[self.cfg["spans_key"]].append(doc[start : end + 1]) offset += len(doc) def update( @@ -271,8 +263,8 @@ class SpanFinder(TrainablePipe): n_tokens = len(eg.predicted) truth = ops.xp.zeros((n_tokens, 2), dtype="float32") mask = ops.xp.ones((n_tokens, 2), dtype="float32") - if self.spans_key in eg.reference.spans: - for span in eg.reference.spans[self.spans_key]: + if self.cfg["spans_key"] in eg.reference.spans: + for span in eg.reference.spans[self.cfg["spans_key"]]: ref_start_char, ref_end_char = _char_indices(span) pred_span = eg.predicted.char_span( ref_start_char, ref_end_char, alignment_mode="expand" diff --git a/spacy/tests/pipeline/test_span_finder.py b/spacy/tests/pipeline/test_span_finder.py index 9c6b6c3e0..91b08cabf 100644 --- a/spacy/tests/pipeline/test_span_finder.py +++ b/spacy/tests/pipeline/test_span_finder.py @@ -209,7 +209,7 @@ def test_overfitting_IO(): # test the trained model test_text = "I like London and Berlin" doc = nlp(test_text) - spans = doc.spans[span_finder.spans_key] + spans = doc.spans[SPANS_KEY] assert len(spans) == 3 assert set([span.text for span in spans]) == { "London", @@ -222,7 +222,7 @@ def test_overfitting_IO(): nlp.to_disk(tmp_dir) nlp2 = util.load_model_from_path(tmp_dir) doc2 = nlp2(test_text) - spans2 = doc2.spans[span_finder.spans_key] + spans2 = doc2.spans[SPANS_KEY] assert len(spans2) == 3 assert set([span.text for span in spans2]) == { "London", @@ -232,11 +232,11 @@ def test_overfitting_IO(): # Test scoring scores = nlp.evaluate(train_examples) - assert f"span_finder_{span_finder.spans_key}_f" in scores + assert f"span_finder_{SPANS_KEY}_f" in scores # It's not perfect 1.0 F1 because it's designed to overgenerate for now. - assert scores[f"span_finder_{span_finder.spans_key}_p"] == 0.75 - assert scores[f"span_finder_{span_finder.spans_key}_r"] == 1.0 + assert scores[f"span_finder_{SPANS_KEY}_p"] == 0.75 + assert scores[f"span_finder_{SPANS_KEY}_r"] == 1.0 # also test that the spancat works for just a single entity in a sentence doc = nlp("London") - assert len(doc.spans[span_finder.spans_key]) == 1 + assert len(doc.spans[SPANS_KEY]) == 1