diff --git a/spacy/errors.py b/spacy/errors.py index 40cfa8d92..9a8cf63c9 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -970,6 +970,8 @@ class Errors(metaclass=ErrorsWithCodes): E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port=port)` " "or use `auto_select_port=True` to pick an available port automatically.") E1051 = ("'allow_overlap' can only be False when max_positive is 1, but found 'max_positive': {max_positive}.") + E1052 = ("Both 'min_length' and 'max_length' should be larger than 1, but found" + " 'min_length': {min_length}, 'max_length': {max_length}") # Deprecated model shortcuts, only used in errors and warnings diff --git a/spacy/pipeline/span_finder.py b/spacy/pipeline/span_finder.py index a59af50e9..988dd4ce9 100644 --- a/spacy/pipeline/span_finder.py +++ b/spacy/pipeline/span_finder.py @@ -9,6 +9,7 @@ from spacy.pipeline.trainable_pipe import TrainablePipe from spacy.scorer import Scorer from spacy.tokens import Doc from spacy.training import Example +from spacy.errors import Errors from ..util import registry from .spancat import DEFAULT_SPANS_KEY, Suggester @@ -52,8 +53,8 @@ DEFAULT_PREDICTED_KEY = "span_candidates" "predicted_key": DEFAULT_PREDICTED_KEY, "training_key": DEFAULT_SPANS_KEY, # XXX Doesn't 0 seem bad compared to None instead? - "max_length": 0, - "min_length": 0, + "max_length": None, + "min_length": None, "scorer": { "@scorers": "spacy.span_finder_scorer.v1", "predicted_key": DEFAULT_PREDICTED_KEY, @@ -72,8 +73,8 @@ def make_span_finder( model: Model[Iterable[Doc], Floats2d], scorer: Optional[Callable], threshold: float, - max_length: int, - min_length: int, + max_length: Optional[int], + min_length: Optional[int], predicted_key: str = DEFAULT_PREDICTED_KEY, training_key: str = DEFAULT_SPANS_KEY, ) -> "SpanFinder": @@ -157,8 +158,8 @@ class SpanFinder(TrainablePipe): name: str = "span_finder", *, threshold: float = 0.5, - max_length: int = 0, - min_length: int = 0, + max_length: Optional[int] = None, + min_length: Optional[int] = None, # XXX I think this is weird and should be just None like in scorer: Optional[Callable] = partial( span_finder_score, @@ -182,6 +183,14 @@ class SpanFinder(TrainablePipe): """ self.vocab = nlp.vocab self.threshold = threshold + if max_length is None: + max_length = float("inf") + if min_length is None: + min_length = 1 + if max_length < 1 or min_length < 1: + raise ValueError( + Errors.E1052.format(min_length=min_length, max_length=max_length) + ) self.max_length = max_length self.min_length = min_length self.predicted_key = predicted_key @@ -227,16 +236,10 @@ class SpanFinder(TrainablePipe): for start in starts: for end in ends: span_length = end + 1 - start - # XXX I really feel like min_length and max_length should be - # None instead of 0 and then just set them to -1 and inf if they - # are given as None. - if span_length > 0: - if ( - self.min_length <= 0 or span_length >= self.min_length - ) and (self.max_length <= 0 or span_length <= self.max_length): - doc.spans[self.predicted_key].append(doc[start : end + 1]) - elif self.max_length > 0 and span_length > self.max_length: - break + if span_length > self.max_length: + break + elif self.min_length <= span_length: + doc.spans[self.predicted_key].append(doc[start : end + 1]) def update( self, diff --git a/spacy/tests/pipeline/test_span_finder.py b/spacy/tests/pipeline/test_span_finder.py index eff863e20..6f6dbaefa 100644 --- a/spacy/tests/pipeline/test_span_finder.py +++ b/spacy/tests/pipeline/test_span_finder.py @@ -106,11 +106,23 @@ def test_span_finder_component(): @pytest.mark.parametrize( - "min_length, max_length, span_count", [(0, 0, 8), (2, 0, 6), (0, 1, 2), (2, 3, 2)] + "min_length, max_length, span_count", + [(0, 0, 0), (None, None, 8), (2, None, 6), (None, 1, 2), (2, 3, 2)] ) def test_set_annotations_span_lengths(min_length, max_length, span_count): nlp = Language() doc = nlp("Me and Jenny goes together like peas and carrots.") + if min_length == 0 and max_length == 0: + with pytest.raises(ValueError, match="Both 'min_length' and 'max_length'"): + span_finder = nlp.add_pipe( + "span_finder", + config={ + "max_length": max_length, + "min_length": min_length, + "training_key": TRAINING_KEY, + }, + ) + return span_finder = nlp.add_pipe( "span_finder", config={ @@ -140,8 +152,10 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count): assert len(doc.spans[DEFAULT_PREDICTED_KEY]) == span_count # Assert below will fail when max_length is set to 0 - if max_length <= 0: - max_length = len(doc) + if max_length is None: + max_length = float("inf") + if min_length is None: + min_length = 1 assert all( min_length <= len(span) <= max_length