mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-16 19:22:34 +03:00
max_length and min_length as Optional[int] and strict checking
This commit is contained in:
parent
3b41a988b0
commit
4ef70c094c
|
@ -970,6 +970,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port=port)` "
|
||||
"or use `auto_select_port=True` to pick an available port automatically.")
|
||||
E1051 = ("'allow_overlap' can only be False when max_positive is 1, but found 'max_positive': {max_positive}.")
|
||||
# Raised when a span length bound is below 1. A value of 1 is valid, so the
# message states the bound as "larger than 0".
E1052 = ("Both 'min_length' and 'max_length' should be larger than 0, but found"
" 'min_length': {min_length}, 'max_length': {max_length}")
|
||||
|
||||
|
||||
# Deprecated model shortcuts, only used in errors and warnings
|
||||
|
|
|
@ -9,6 +9,7 @@ from spacy.pipeline.trainable_pipe import TrainablePipe
|
|||
from spacy.scorer import Scorer
|
||||
from spacy.tokens import Doc
|
||||
from spacy.training import Example
|
||||
from spacy.errors import Errors
|
||||
|
||||
from ..util import registry
|
||||
from .spancat import DEFAULT_SPANS_KEY, Suggester
|
||||
|
@ -52,8 +53,8 @@ DEFAULT_PREDICTED_KEY = "span_candidates"
|
|||
"predicted_key": DEFAULT_PREDICTED_KEY,
|
||||
"training_key": DEFAULT_SPANS_KEY,
|
||||
# None means no explicit limit is configured; SpanFinder substitutes its own defaults.
|
||||
"max_length": 0,
|
||||
"min_length": 0,
|
||||
"max_length": None,
|
||||
"min_length": None,
|
||||
"scorer": {
|
||||
"@scorers": "spacy.span_finder_scorer.v1",
|
||||
"predicted_key": DEFAULT_PREDICTED_KEY,
|
||||
|
@ -72,8 +73,8 @@ def make_span_finder(
|
|||
model: Model[Iterable[Doc], Floats2d],
|
||||
scorer: Optional[Callable],
|
||||
threshold: float,
|
||||
max_length: int,
|
||||
min_length: int,
|
||||
max_length: Optional[int],
|
||||
min_length: Optional[int],
|
||||
predicted_key: str = DEFAULT_PREDICTED_KEY,
|
||||
training_key: str = DEFAULT_SPANS_KEY,
|
||||
) -> "SpanFinder":
|
||||
|
@ -157,8 +158,8 @@ class SpanFinder(TrainablePipe):
|
|||
name: str = "span_finder",
|
||||
*,
|
||||
threshold: float = 0.5,
|
||||
max_length: int = 0,
|
||||
min_length: int = 0,
|
||||
max_length: Optional[int] = None,
|
||||
min_length: Optional[int] = None,
|
||||
# XXX I think this is weird and should be just None like in
|
||||
scorer: Optional[Callable] = partial(
|
||||
span_finder_score,
|
||||
|
@ -182,6 +183,14 @@ class SpanFinder(TrainablePipe):
|
|||
"""
|
||||
self.vocab = nlp.vocab
|
||||
self.threshold = threshold
|
||||
if max_length is None:
|
||||
max_length = float("inf")
|
||||
if min_length is None:
|
||||
min_length = 1
|
||||
if max_length < 1 or min_length < 1:
|
||||
raise ValueError(
|
||||
Errors.E1052.format(min_length=min_length, max_length=max_length)
|
||||
)
|
||||
self.max_length = max_length
|
||||
self.min_length = min_length
|
||||
self.predicted_key = predicted_key
|
||||
|
@ -227,16 +236,10 @@ class SpanFinder(TrainablePipe):
|
|||
for start in starts:
|
||||
for end in ends:
|
||||
span_length = end + 1 - start
|
||||
# XXX I really feel like min_length and max_length should be
|
||||
# None instead of 0 and then just set them to -1 and inf if they
|
||||
# are given as None.
|
||||
if span_length > 0:
|
||||
if (
|
||||
self.min_length <= 0 or span_length >= self.min_length
|
||||
) and (self.max_length <= 0 or span_length <= self.max_length):
|
||||
doc.spans[self.predicted_key].append(doc[start : end + 1])
|
||||
elif self.max_length > 0 and span_length > self.max_length:
|
||||
break
|
||||
if span_length > self.max_length:
|
||||
break
|
||||
elif self.min_length <= span_length:
|
||||
doc.spans[self.predicted_key].append(doc[start : end + 1])
|
||||
|
||||
def update(
|
||||
self,
|
||||
|
|
|
@ -106,11 +106,23 @@ def test_span_finder_component():
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"min_length, max_length, span_count", [(0, 0, 8), (2, 0, 6), (0, 1, 2), (2, 3, 2)]
|
||||
"min_length, max_length, span_count",
|
||||
[(0, 0, 0), (None, None, 8), (2, None, 6), (None, 1, 2), (2, 3, 2)]
|
||||
)
|
||||
def test_set_annotations_span_lengths(min_length, max_length, span_count):
|
||||
nlp = Language()
|
||||
doc = nlp("Me and Jenny goes together like peas and carrots.")
|
||||
if min_length == 0 and max_length == 0:
|
||||
with pytest.raises(ValueError, match="Both 'min_length' and 'max_length'"):
|
||||
span_finder = nlp.add_pipe(
|
||||
"span_finder",
|
||||
config={
|
||||
"max_length": max_length,
|
||||
"min_length": min_length,
|
||||
"training_key": TRAINING_KEY,
|
||||
},
|
||||
)
|
||||
return
|
||||
span_finder = nlp.add_pipe(
|
||||
"span_finder",
|
||||
config={
|
||||
|
@ -140,8 +152,10 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
|
|||
assert len(doc.spans[DEFAULT_PREDICTED_KEY]) == span_count
|
||||
|
||||
# The assert below cannot compare lengths against None, so map None to the effective defaults first
|
||||
if max_length <= 0:
|
||||
max_length = len(doc)
|
||||
if max_length is None:
|
||||
max_length = float("inf")
|
||||
if min_length is None:
|
||||
min_length = 1
|
||||
|
||||
assert all(
|
||||
min_length <= len(span) <= max_length
|
||||
|
|
Loading…
Reference in New Issue
Block a user