max_length and min_length as Optional[int] and strict checking

This commit is contained in:
kadarakos 2023-05-03 11:00:18 +00:00
parent 3b41a988b0
commit 4ef70c094c
3 changed files with 38 additions and 19 deletions

View File

@ -970,6 +970,8 @@ class Errors(metaclass=ErrorsWithCodes):
E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port=port)` "
"or use `auto_select_port=True` to pick an available port automatically.")
E1051 = ("'allow_overlap' can only be False when max_positive is 1, but found 'max_positive': {max_positive}.")
E1052 = ("Both 'min_length' and 'max_length' should be larger than 1, but found"
" 'min_length': {min_length}, 'max_length': {max_length}")
# Deprecated model shortcuts, only used in errors and warnings

View File

@ -9,6 +9,7 @@ from spacy.pipeline.trainable_pipe import TrainablePipe
from spacy.scorer import Scorer
from spacy.tokens import Doc
from spacy.training import Example
from spacy.errors import Errors
from ..util import registry
from .spancat import DEFAULT_SPANS_KEY, Suggester
@ -52,8 +53,8 @@ DEFAULT_PREDICTED_KEY = "span_candidates"
"predicted_key": DEFAULT_PREDICTED_KEY,
"training_key": DEFAULT_SPANS_KEY,
# XXX Doesn't 0 seem bad compared to None instead?
"max_length": 0,
"min_length": 0,
"max_length": None,
"min_length": None,
"scorer": {
"@scorers": "spacy.span_finder_scorer.v1",
"predicted_key": DEFAULT_PREDICTED_KEY,
@ -72,8 +73,8 @@ def make_span_finder(
model: Model[Iterable[Doc], Floats2d],
scorer: Optional[Callable],
threshold: float,
max_length: int,
min_length: int,
max_length: Optional[int],
min_length: Optional[int],
predicted_key: str = DEFAULT_PREDICTED_KEY,
training_key: str = DEFAULT_SPANS_KEY,
) -> "SpanFinder":
@ -157,8 +158,8 @@ class SpanFinder(TrainablePipe):
name: str = "span_finder",
*,
threshold: float = 0.5,
max_length: int = 0,
min_length: int = 0,
max_length: Optional[int] = None,
min_length: Optional[int] = None,
# XXX I think this is weird and should be just None like in
scorer: Optional[Callable] = partial(
span_finder_score,
@ -182,6 +183,14 @@ class SpanFinder(TrainablePipe):
"""
self.vocab = nlp.vocab
self.threshold = threshold
if max_length is None:
max_length = float("inf")
if min_length is None:
min_length = 1
if max_length < 1 or min_length < 1:
raise ValueError(
Errors.E1052.format(min_length=min_length, max_length=max_length)
)
self.max_length = max_length
self.min_length = min_length
self.predicted_key = predicted_key
@ -227,16 +236,10 @@ class SpanFinder(TrainablePipe):
for start in starts:
for end in ends:
span_length = end + 1 - start
# XXX I really feel like min_length and max_length should be
# None instead of 0 and then just set them to -1 and inf if they
# are given as None.
if span_length > 0:
if (
self.min_length <= 0 or span_length >= self.min_length
) and (self.max_length <= 0 or span_length <= self.max_length):
doc.spans[self.predicted_key].append(doc[start : end + 1])
elif self.max_length > 0 and span_length > self.max_length:
break
if span_length > self.max_length:
break
elif self.min_length <= span_length:
doc.spans[self.predicted_key].append(doc[start : end + 1])
def update(
self,

View File

@ -106,11 +106,23 @@ def test_span_finder_component():
@pytest.mark.parametrize(
"min_length, max_length, span_count", [(0, 0, 8), (2, 0, 6), (0, 1, 2), (2, 3, 2)]
"min_length, max_length, span_count",
[(0, 0, 0), (None, None, 8), (2, None, 6), (None, 1, 2), (2, 3, 2)]
)
def test_set_annotations_span_lengths(min_length, max_length, span_count):
nlp = Language()
doc = nlp("Me and Jenny goes together like peas and carrots.")
if min_length == 0 and max_length == 0:
with pytest.raises(ValueError, match="Both 'min_length' and 'max_length'"):
span_finder = nlp.add_pipe(
"span_finder",
config={
"max_length": max_length,
"min_length": min_length,
"training_key": TRAINING_KEY,
},
)
return
span_finder = nlp.add_pipe(
"span_finder",
config={
@ -140,8 +152,10 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
assert len(doc.spans[DEFAULT_PREDICTED_KEY]) == span_count
# Assert below will fail when max_length is set to 0
if max_length <= 0:
max_length = len(doc)
if max_length is None:
max_length = float("inf")
if min_length is None:
min_length = 1
assert all(
min_length <= len(span) <= max_length