max_length and min_length as Optional[int] and strict checking

This commit is contained in:
kadarakos 2023-05-03 11:00:18 +00:00
parent 3b41a988b0
commit 4ef70c094c
3 changed files with 38 additions and 19 deletions

View File

@ -970,6 +970,8 @@ class Errors(metaclass=ErrorsWithCodes):
E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port=port)` " E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port=port)` "
"or use `auto_select_port=True` to pick an available port automatically.") "or use `auto_select_port=True` to pick an available port automatically.")
E1051 = ("'allow_overlap' can only be False when max_positive is 1, but found 'max_positive': {max_positive}.") E1051 = ("'allow_overlap' can only be False when max_positive is 1, but found 'max_positive': {max_positive}.")
E1052 = ("Both 'min_length' and 'max_length' should be larger than 0, but found"
" 'min_length': {min_length}, 'max_length': {max_length}")
# Deprecated model shortcuts, only used in errors and warnings # Deprecated model shortcuts, only used in errors and warnings

View File

@ -9,6 +9,7 @@ from spacy.pipeline.trainable_pipe import TrainablePipe
from spacy.scorer import Scorer from spacy.scorer import Scorer
from spacy.tokens import Doc from spacy.tokens import Doc
from spacy.training import Example from spacy.training import Example
from spacy.errors import Errors
from ..util import registry from ..util import registry
from .spancat import DEFAULT_SPANS_KEY, Suggester from .spancat import DEFAULT_SPANS_KEY, Suggester
@ -52,8 +53,8 @@ DEFAULT_PREDICTED_KEY = "span_candidates"
"predicted_key": DEFAULT_PREDICTED_KEY, "predicted_key": DEFAULT_PREDICTED_KEY,
"training_key": DEFAULT_SPANS_KEY, "training_key": DEFAULT_SPANS_KEY,
# XXX Doesn't 0 seem bad compared to None instead? # XXX Doesn't 0 seem bad compared to None instead?
"max_length": 0, "max_length": None,
"min_length": 0, "min_length": None,
"scorer": { "scorer": {
"@scorers": "spacy.span_finder_scorer.v1", "@scorers": "spacy.span_finder_scorer.v1",
"predicted_key": DEFAULT_PREDICTED_KEY, "predicted_key": DEFAULT_PREDICTED_KEY,
@ -72,8 +73,8 @@ def make_span_finder(
model: Model[Iterable[Doc], Floats2d], model: Model[Iterable[Doc], Floats2d],
scorer: Optional[Callable], scorer: Optional[Callable],
threshold: float, threshold: float,
max_length: int, max_length: Optional[int],
min_length: int, min_length: Optional[int],
predicted_key: str = DEFAULT_PREDICTED_KEY, predicted_key: str = DEFAULT_PREDICTED_KEY,
training_key: str = DEFAULT_SPANS_KEY, training_key: str = DEFAULT_SPANS_KEY,
) -> "SpanFinder": ) -> "SpanFinder":
@ -157,8 +158,8 @@ class SpanFinder(TrainablePipe):
name: str = "span_finder", name: str = "span_finder",
*, *,
threshold: float = 0.5, threshold: float = 0.5,
max_length: int = 0, max_length: Optional[int] = None,
min_length: int = 0, min_length: Optional[int] = None,
# XXX I think this is weird and should be just None like in # XXX I think this is weird and should be just None like in
scorer: Optional[Callable] = partial( scorer: Optional[Callable] = partial(
span_finder_score, span_finder_score,
@ -182,6 +183,14 @@ class SpanFinder(TrainablePipe):
""" """
self.vocab = nlp.vocab self.vocab = nlp.vocab
self.threshold = threshold self.threshold = threshold
if max_length is None:
max_length = float("inf")
if min_length is None:
min_length = 1
if max_length < 1 or min_length < 1:
raise ValueError(
Errors.E1052.format(min_length=min_length, max_length=max_length)
)
self.max_length = max_length self.max_length = max_length
self.min_length = min_length self.min_length = min_length
self.predicted_key = predicted_key self.predicted_key = predicted_key
@ -227,16 +236,10 @@ class SpanFinder(TrainablePipe):
for start in starts: for start in starts:
for end in ends: for end in ends:
span_length = end + 1 - start span_length = end + 1 - start
# XXX I really feel like min_length and max_length should be if span_length > self.max_length:
# None instead of 0 and then just set them to -1 and inf if they
# are given as None.
if span_length > 0:
if (
self.min_length <= 0 or span_length >= self.min_length
) and (self.max_length <= 0 or span_length <= self.max_length):
doc.spans[self.predicted_key].append(doc[start : end + 1])
elif self.max_length > 0 and span_length > self.max_length:
break break
elif self.min_length <= span_length:
doc.spans[self.predicted_key].append(doc[start : end + 1])
def update( def update(
self, self,

View File

@ -106,11 +106,23 @@ def test_span_finder_component():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"min_length, max_length, span_count", [(0, 0, 8), (2, 0, 6), (0, 1, 2), (2, 3, 2)] "min_length, max_length, span_count",
[(0, 0, 0), (None, None, 8), (2, None, 6), (None, 1, 2), (2, 3, 2)]
) )
def test_set_annotations_span_lengths(min_length, max_length, span_count): def test_set_annotations_span_lengths(min_length, max_length, span_count):
nlp = Language() nlp = Language()
doc = nlp("Me and Jenny goes together like peas and carrots.") doc = nlp("Me and Jenny goes together like peas and carrots.")
if min_length == 0 and max_length == 0:
with pytest.raises(ValueError, match="Both 'min_length' and 'max_length'"):
span_finder = nlp.add_pipe(
"span_finder",
config={
"max_length": max_length,
"min_length": min_length,
"training_key": TRAINING_KEY,
},
)
return
span_finder = nlp.add_pipe( span_finder = nlp.add_pipe(
"span_finder", "span_finder",
config={ config={
@ -140,8 +152,10 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
assert len(doc.spans[DEFAULT_PREDICTED_KEY]) == span_count assert len(doc.spans[DEFAULT_PREDICTED_KEY]) == span_count
# Assert below will fail when max_length is set to 0 # Assert below will fail when max_length is set to 0
if max_length <= 0: if max_length is None:
max_length = len(doc) max_length = float("inf")
if min_length is None:
min_length = 1
assert all( assert all(
min_length <= len(span) <= max_length min_length <= len(span) <= max_length