mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 04:02:20 +03:00
max_length and min_length as Optional[int] and strict checking
This commit is contained in:
parent
3b41a988b0
commit
4ef70c094c
|
@ -970,6 +970,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
||||||
E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port=port)` "
|
E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port=port)` "
|
||||||
"or use `auto_select_port=True` to pick an available port automatically.")
|
"or use `auto_select_port=True` to pick an available port automatically.")
|
||||||
E1051 = ("'allow_overlap' can only be False when max_positive is 1, but found 'max_positive': {max_positive}.")
|
E1051 = ("'allow_overlap' can only be False when max_positive is 1, but found 'max_positive': {max_positive}.")
|
||||||
|
E1052 = ("Both 'min_length' and 'max_length' should be larger than 1, but found"
|
||||||
|
" 'min_length': {min_length}, 'max_length': {max_length}")
|
||||||
|
|
||||||
|
|
||||||
# Deprecated model shortcuts, only used in errors and warnings
|
# Deprecated model shortcuts, only used in errors and warnings
|
||||||
|
|
|
@ -9,6 +9,7 @@ from spacy.pipeline.trainable_pipe import TrainablePipe
|
||||||
from spacy.scorer import Scorer
|
from spacy.scorer import Scorer
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
from spacy.training import Example
|
from spacy.training import Example
|
||||||
|
from spacy.errors import Errors
|
||||||
|
|
||||||
from ..util import registry
|
from ..util import registry
|
||||||
from .spancat import DEFAULT_SPANS_KEY, Suggester
|
from .spancat import DEFAULT_SPANS_KEY, Suggester
|
||||||
|
@ -52,8 +53,8 @@ DEFAULT_PREDICTED_KEY = "span_candidates"
|
||||||
"predicted_key": DEFAULT_PREDICTED_KEY,
|
"predicted_key": DEFAULT_PREDICTED_KEY,
|
||||||
"training_key": DEFAULT_SPANS_KEY,
|
"training_key": DEFAULT_SPANS_KEY,
|
||||||
# XXX Doesn't 0 seem bad compared to None instead?
|
# XXX Doesn't 0 seem bad compared to None instead?
|
||||||
"max_length": 0,
|
"max_length": None,
|
||||||
"min_length": 0,
|
"min_length": None,
|
||||||
"scorer": {
|
"scorer": {
|
||||||
"@scorers": "spacy.span_finder_scorer.v1",
|
"@scorers": "spacy.span_finder_scorer.v1",
|
||||||
"predicted_key": DEFAULT_PREDICTED_KEY,
|
"predicted_key": DEFAULT_PREDICTED_KEY,
|
||||||
|
@ -72,8 +73,8 @@ def make_span_finder(
|
||||||
model: Model[Iterable[Doc], Floats2d],
|
model: Model[Iterable[Doc], Floats2d],
|
||||||
scorer: Optional[Callable],
|
scorer: Optional[Callable],
|
||||||
threshold: float,
|
threshold: float,
|
||||||
max_length: int,
|
max_length: Optional[int],
|
||||||
min_length: int,
|
min_length: Optional[int],
|
||||||
predicted_key: str = DEFAULT_PREDICTED_KEY,
|
predicted_key: str = DEFAULT_PREDICTED_KEY,
|
||||||
training_key: str = DEFAULT_SPANS_KEY,
|
training_key: str = DEFAULT_SPANS_KEY,
|
||||||
) -> "SpanFinder":
|
) -> "SpanFinder":
|
||||||
|
@ -157,8 +158,8 @@ class SpanFinder(TrainablePipe):
|
||||||
name: str = "span_finder",
|
name: str = "span_finder",
|
||||||
*,
|
*,
|
||||||
threshold: float = 0.5,
|
threshold: float = 0.5,
|
||||||
max_length: int = 0,
|
max_length: Optional[int] = None,
|
||||||
min_length: int = 0,
|
min_length: Optional[int] = None,
|
||||||
# XXX I think this is weird and should be just None like in
|
# XXX I think this is weird and should be just None like in
|
||||||
scorer: Optional[Callable] = partial(
|
scorer: Optional[Callable] = partial(
|
||||||
span_finder_score,
|
span_finder_score,
|
||||||
|
@ -182,6 +183,14 @@ class SpanFinder(TrainablePipe):
|
||||||
"""
|
"""
|
||||||
self.vocab = nlp.vocab
|
self.vocab = nlp.vocab
|
||||||
self.threshold = threshold
|
self.threshold = threshold
|
||||||
|
if max_length is None:
|
||||||
|
max_length = float("inf")
|
||||||
|
if min_length is None:
|
||||||
|
min_length = 1
|
||||||
|
if max_length < 1 or min_length < 1:
|
||||||
|
raise ValueError(
|
||||||
|
Errors.E1052.format(min_length=min_length, max_length=max_length)
|
||||||
|
)
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.min_length = min_length
|
self.min_length = min_length
|
||||||
self.predicted_key = predicted_key
|
self.predicted_key = predicted_key
|
||||||
|
@ -227,16 +236,10 @@ class SpanFinder(TrainablePipe):
|
||||||
for start in starts:
|
for start in starts:
|
||||||
for end in ends:
|
for end in ends:
|
||||||
span_length = end + 1 - start
|
span_length = end + 1 - start
|
||||||
# XXX I really feel like min_length and max_length should be
|
if span_length > self.max_length:
|
||||||
# None instead of 0 and then just set them to -1 and inf if they
|
|
||||||
# are given as None.
|
|
||||||
if span_length > 0:
|
|
||||||
if (
|
|
||||||
self.min_length <= 0 or span_length >= self.min_length
|
|
||||||
) and (self.max_length <= 0 or span_length <= self.max_length):
|
|
||||||
doc.spans[self.predicted_key].append(doc[start : end + 1])
|
|
||||||
elif self.max_length > 0 and span_length > self.max_length:
|
|
||||||
break
|
break
|
||||||
|
elif self.min_length <= span_length:
|
||||||
|
doc.spans[self.predicted_key].append(doc[start : end + 1])
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -106,11 +106,23 @@ def test_span_finder_component():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"min_length, max_length, span_count", [(0, 0, 8), (2, 0, 6), (0, 1, 2), (2, 3, 2)]
|
"min_length, max_length, span_count",
|
||||||
|
[(0, 0, 0), (None, None, 8), (2, None, 6), (None, 1, 2), (2, 3, 2)]
|
||||||
)
|
)
|
||||||
def test_set_annotations_span_lengths(min_length, max_length, span_count):
|
def test_set_annotations_span_lengths(min_length, max_length, span_count):
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
doc = nlp("Me and Jenny goes together like peas and carrots.")
|
doc = nlp("Me and Jenny goes together like peas and carrots.")
|
||||||
|
if min_length == 0 and max_length == 0:
|
||||||
|
with pytest.raises(ValueError, match="Both 'min_length' and 'max_length'"):
|
||||||
|
span_finder = nlp.add_pipe(
|
||||||
|
"span_finder",
|
||||||
|
config={
|
||||||
|
"max_length": max_length,
|
||||||
|
"min_length": min_length,
|
||||||
|
"training_key": TRAINING_KEY,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return
|
||||||
span_finder = nlp.add_pipe(
|
span_finder = nlp.add_pipe(
|
||||||
"span_finder",
|
"span_finder",
|
||||||
config={
|
config={
|
||||||
|
@ -140,8 +152,10 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
|
||||||
assert len(doc.spans[DEFAULT_PREDICTED_KEY]) == span_count
|
assert len(doc.spans[DEFAULT_PREDICTED_KEY]) == span_count
|
||||||
|
|
||||||
# Assert below will fail when max_length is set to 0
|
# Assert below will fail when max_length is set to 0
|
||||||
if max_length <= 0:
|
if max_length is None:
|
||||||
max_length = len(doc)
|
max_length = float("inf")
|
||||||
|
if min_length is None:
|
||||||
|
min_length = 1
|
||||||
|
|
||||||
assert all(
|
assert all(
|
||||||
min_length <= len(span) <= max_length
|
min_length <= len(span) <= max_length
|
||||||
|
|
Loading…
Reference in New Issue
Block a user