diff --git a/spacy/errors.py b/spacy/errors.py index eadbf63d6..9ddf4eaaa 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -923,7 +923,7 @@ class Errors(metaclass=ErrorsWithCodes): E1029 = ("Edit tree cannot be applied to form.") E1030 = ("Edit tree identifier out of range.") E1031 = ("Could not find gold transition - see logs above.") - E1032 = ("`{var}` should not be {forbidden}, but received {value}.") + E1032 = ("Span {var} {value} is out of bounds for {obj} with length {length}.") E1033 = ("Dimension {name} invalid -- only nO, nF, nP") E1034 = ("Node index {i} out of bounds ({length})") E1035 = ("Token index {i} out of bounds ({length})") @@ -961,6 +961,9 @@ class Errors(metaclass=ErrorsWithCodes): E4003 = ("Training examples for distillation must have the exact same tokens in the " "reference and predicted docs.") E4004 = ("Backprop is not supported when is_train is not set.") + E4005 = ("Span {var} {value} must be {op} Span {existing_var} " + "{existing_value}.") + E4006 = ("Span {pos}_char {value} does not correspond to a token {pos}.") RENAMED_LANGUAGE_CODES = {"xx": "mul", "is": "isl"} diff --git a/spacy/tests/doc/test_span.py b/spacy/tests/doc/test_span.py index a99f8b561..8eabcd645 100644 --- a/spacy/tests/doc/test_span.py +++ b/spacy/tests/doc/test_span.py @@ -707,3 +707,50 @@ def test_span_ent_id(en_tokenizer): doc.ents = [span] assert doc.ents[0].ent_id_ == "ID2" assert doc[1].ent_id_ == "ID2" + + +def test_span_start_end_sync(en_tokenizer): + doc = en_tokenizer("a bc def e fghij kl") + # can create and edit span starts/ends + span = doc[2:4] + span.start_char = 2 + span.end = 5 + assert span == doc[span.start : span.end] + assert span == doc.char_span(span.start_char, span.end_char) + # cannot set completely out of bounds starts/ends + with pytest.raises(IndexError): + span.start = -1 + with pytest.raises(IndexError): + span.end = -1 + with pytest.raises(IndexError): + span.start_char = len(doc.text) + 1 + with pytest.raises(IndexError): + span.end = len(doc.text) + 1 + # test all possible char starts/ends + span = doc[0 : len(doc)] + token_char_starts = [token.idx for token in doc] + token_char_ends = [token.idx + len(token.text) for token in doc] + for i in range(len(doc.text)): + if i not in token_char_starts: + with pytest.raises(ValueError): + span.start_char = i + else: + span.start_char = i + span = doc[0 : len(doc)] + for i in range(len(doc.text)): + if i not in token_char_ends: + with pytest.raises(ValueError): + span.end_char = i + else: + span.end_char = i + # start must be <= end + span = doc[1:3] + with pytest.raises(ValueError): + span.start = 4 + with pytest.raises(ValueError): + span.end = 0 + span = doc.char_span(2, 8) + with pytest.raises(ValueError): + span.start_char = 9 + with pytest.raises(ValueError): + span.end_char = 1 diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx index 4990cb5f7..0d59c5af2 100644 --- a/spacy/tokens/span.pyx +++ b/spacy/tokens/span.pyx @@ -770,36 +770,61 @@ cdef class Span: return self.span_c().start def __set__(self, int start): - if start < 0: - raise IndexError(Errors.E1032.format(var="start", forbidden="< 0", value=start)) - self.span_c().start = start + if start < 0 or start > self.doc.length: + raise IndexError(Errors.E1032.format(var="start", obj="Doc", length=self.doc.length, value=start)) + cdef SpanC* span_c = self.span_c() + if start > span_c.end: + raise ValueError(Errors.E4005.format(var="start", value=start, op="<=", existing_var="end", existing_value=span_c.end)) + span_c.start = start + span_c.start_char = self.doc.c[start].idx property end: def __get__(self): return self.span_c().end def __set__(self, int end): - if end < 0: - raise IndexError(Errors.E1032.format(var="end", forbidden="< 0", value=end)) - self.span_c().end = end + if end < 0 or end > self.doc.length: + raise IndexError(Errors.E1032.format(var="end", obj="Doc", length=self.doc.length, value=end)) + cdef SpanC* span_c = self.span_c() + if span_c.start > end: + raise ValueError(Errors.E4005.format(var="end", value=end, op=">=", existing_var="start", existing_value=span_c.start)) + span_c.end = end + if end > 0: + span_c.end_char = self.doc.c[end-1].idx + self.doc.c[end-1].lex.length + else: + span_c.end_char = 0 property start_char: def __get__(self): return self.span_c().start_char def __set__(self, int start_char): - if start_char < 0: - raise IndexError(Errors.E1032.format(var="start_char", forbidden="< 0", value=start_char)) - self.span_c().start_char = start_char + if start_char < 0 or start_char > len(self.doc.text): + raise IndexError(Errors.E1032.format(var="start_char", obj="Doc text", length=len(self.doc.text), value=start_char)) + cdef int start = token_by_start(self.doc.c, self.doc.length, start_char) + if start < 0: + raise ValueError(Errors.E4006.format(value=start_char, pos="start")) + cdef SpanC* span_c = self.span_c() + if start_char > span_c.end_char: + raise ValueError(Errors.E4005.format(var="start_char", value=start_char, op="<=", existing_var="end_char", existing_value=span_c.end_char)) + span_c.start_char = start_char + span_c.start = start property end_char: def __get__(self): return self.span_c().end_char def __set__(self, int end_char): - if end_char < 0: - raise IndexError(Errors.E1032.format(var="end_char", forbidden="< 0", value=end_char)) - self.span_c().end_char = end_char + if end_char < 0 or end_char > len(self.doc.text): + raise IndexError(Errors.E1032.format(var="end_char", obj="Doc text", length=len(self.doc.text), value=end_char)) + cdef int end = token_by_end(self.doc.c, self.doc.length, end_char) + if end < 0: + raise ValueError(Errors.E4006.format(value=end_char, pos="end")) + cdef SpanC* span_c = self.span_c() + if span_c.start_char > end_char: + raise ValueError(Errors.E4005.format(var="end_char", value=end_char, op=">=", existing_var="start_char", existing_value=span_c.start_char)) + span_c.end_char = end_char + span_c.end = end property label: def __get__(self):