diff --git a/spacy/tokens/spans.pyx b/spacy/tokens/spans.pyx index c39f8976c..99efad4b9 100644 --- a/spacy/tokens/spans.pyx +++ b/spacy/tokens/spans.pyx @@ -16,9 +16,13 @@ cdef class Span: def __cinit__(self, Doc tokens, int start, int end, int label=0, vector=None, vector_norm=None): if start < 0: - start = tokens.length - start + start = tokens.length + start + start = min(tokens.length, max(0, start)) + if end < 0: - end = tokens.length - end + end = tokens.length + end + end = min(tokens.length, max(start, end)) + self.doc = tokens self.start = start self.end = end diff --git a/tests/tokens/test_tokens_api.py b/tests/tokens/test_tokens_api.py index fc1b52143..a272a8e3b 100644 --- a/tests/tokens/test_tokens_api.py +++ b/tests/tokens/test_tokens_api.py @@ -12,17 +12,51 @@ def test_getitem(EN): with pytest.raises(IndexError): tokens[len(tokens)] + def to_str(span): + return '/'.join(token.orth_ for token in span) + span = tokens[1:1] - assert not '/'.join(token.orth_ for token in span) + assert not to_str(span) span = tokens[1:4] - assert '/'.join(token.orth_ for token in span) == 'it/back/!' + assert to_str(span) == 'it/back/!' span = tokens[1:4:1] - assert '/'.join(token.orth_ for token in span) == 'it/back/!' + assert to_str(span) == 'it/back/!' with pytest.raises(ValueError): tokens[1:4:2] with pytest.raises(ValueError): tokens[1:4:-1] + span = tokens[-3:6] + assert to_str(span) == 'He/pleaded' + span = tokens[4:-1] + assert to_str(span) == 'He/pleaded' + span = tokens[-5:-3] + assert to_str(span) == 'back/!' + span = tokens[5:4] + assert span.start == span.end == 5 and not to_str(span) + span = tokens[4:-3] + assert span.start == span.end == 4 and not to_str(span) + + span = tokens[:] + assert to_str(span) == 'Give/it/back/!/He/pleaded/.' + span = tokens[4:] + assert to_str(span) == 'He/pleaded/.' + span = tokens[:4] + assert to_str(span) == 'Give/it/back/!' + span = tokens[:-3] + assert to_str(span) == 'Give/it/back/!' + span = tokens[-3:] + assert to_str(span) == 'He/pleaded/.' + + span = tokens[4:50] + assert to_str(span) == 'He/pleaded/.' + span = tokens[-50:4] + assert to_str(span) == 'Give/it/back/!' + span = tokens[-50:-40] + assert span.start == span.end == 0 and not to_str(span) + span = tokens[40:50] + assert span.start == span.end == 7 and not to_str(span) + @pytest.mark.models def test_serialize(EN):