From 32b3ca3d887fdf186918203ab587de5d20fa2a07 Mon Sep 17 00:00:00 2001
From: Adriane Boyd
Date: Thu, 9 Feb 2023 22:04:59 +0100
Subject: [PATCH] Enforce that Span.start/end(_char) remain valid and in sync

Allowing span attributes to be writable starting in v3 has made it
possible for the internal `Span.start/end/start_char/end_char` to get
out-of-sync or have invalid values.

This checks that the values are valid and syncs the token and char
offsets if any attributes are modified directly. It does not yet handle
the case where the underlying doc is modified.
---
 spacy/errors.py              |  5 +++-
 spacy/tests/doc/test_span.py | 47 ++++++++++++++++++++++++++++++++++
 spacy/tokens/span.pyx        | 49 +++++++++++++++++++++++++++---------
 3 files changed, 88 insertions(+), 13 deletions(-)

diff --git a/spacy/errors.py b/spacy/errors.py
index eadbf63d6..9ddf4eaaa 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -923,7 +923,7 @@ class Errors(metaclass=ErrorsWithCodes):
     E1029 = ("Edit tree cannot be applied to form.")
     E1030 = ("Edit tree identifier out of range.")
     E1031 = ("Could not find gold transition - see logs above.")
-    E1032 = ("`{var}` should not be {forbidden}, but received {value}.")
+    E1032 = ("Span {var} {value} is out of bounds for {obj} with length {length}.")
     E1033 = ("Dimension {name} invalid -- only nO, nF, nP")
     E1034 = ("Node index {i} out of bounds ({length})")
     E1035 = ("Token index {i} out of bounds ({length})")
@@ -961,6 +961,9 @@ class Errors(metaclass=ErrorsWithCodes):
     E4003 = ("Training examples for distillation must have the exact same tokens in the "
              "reference and predicted docs.")
     E4004 = ("Backprop is not supported when is_train is not set.")
+    E4005 = ("Span {var} {value} must be {op} Span {existing_var} "
+             "{existing_value}.")
+    E4006 = ("Span {pos}_char {value} does not correspond to a token {pos}.")
 
 
 RENAMED_LANGUAGE_CODES = {"xx": "mul", "is": "isl"}
diff --git a/spacy/tests/doc/test_span.py b/spacy/tests/doc/test_span.py
index a99f8b561..8eabcd645 100644
--- a/spacy/tests/doc/test_span.py
+++ b/spacy/tests/doc/test_span.py
@@ -707,3 +707,50 @@ def test_span_ent_id(en_tokenizer):
     doc.ents = [span]
     assert doc.ents[0].ent_id_ == "ID2"
     assert doc[1].ent_id_ == "ID2"
+
+
+def test_span_start_end_sync(en_tokenizer):
+    doc = en_tokenizer("a bc def e fghij kl")
+    # can create and edit span starts/ends
+    span = doc[2:4]
+    span.start_char = 2
+    span.end = 5
+    assert span == doc[span.start : span.end]
+    assert span == doc.char_span(span.start_char, span.end_char)
+    # cannot set completely out of bounds starts/ends
+    with pytest.raises(IndexError):
+        span.start = -1
+    with pytest.raises(IndexError):
+        span.end = -1
+    with pytest.raises(IndexError):
+        span.start_char = len(doc.text) + 1
+    with pytest.raises(IndexError):
+        span.end = len(doc.text) + 1
+    # test all possible char starts/ends
+    span = doc[0 : len(doc)]
+    token_char_starts = [token.idx for token in doc]
+    token_char_ends = [token.idx + len(token.text) for token in doc]
+    for i in range(len(doc.text)):
+        if i not in token_char_starts:
+            with pytest.raises(ValueError):
+                span.start_char = i
+        else:
+            span.start_char = i
+    span = doc[0 : len(doc)]
+    for i in range(len(doc.text)):
+        if i not in token_char_ends:
+            with pytest.raises(ValueError):
+                span.end_char = i
+        else:
+            span.end_char = i
+    # start must be <= end
+    span = doc[1:3]
+    with pytest.raises(ValueError):
+        span.start = 4
+    with pytest.raises(ValueError):
+        span.end = 0
+    span = doc.char_span(2, 8)
+    with pytest.raises(ValueError):
+        span.start_char = 9
+    with pytest.raises(ValueError):
+        span.end_char = 1
diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx
index 4990cb5f7..0d59c5af2 100644
--- a/spacy/tokens/span.pyx
+++ b/spacy/tokens/span.pyx
@@ -770,36 +770,61 @@ cdef class Span:
             return self.span_c().start
 
         def __set__(self, int start):
-            if start < 0:
-                raise IndexError(Errors.E1032.format(var="start", forbidden="< 0", value=start))
-            self.span_c().start = start
+            if start < 0 or start > self.doc.length:
+                raise IndexError(Errors.E1032.format(var="start", obj="Doc", length=self.doc.length, value=start))
+            cdef SpanC* span_c = self.span_c()
+            if start > span_c.end:
+                raise ValueError(Errors.E4005.format(var="start", value=start, op="<=", existing_var="end", existing_value=span_c.end))
+            span_c.start = start
+            span_c.start_char = self.doc.c[start].idx
 
     property end:
         def __get__(self):
             return self.span_c().end
 
         def __set__(self, int end):
-            if end < 0:
-                raise IndexError(Errors.E1032.format(var="end", forbidden="< 0", value=end))
-            self.span_c().end = end
+            if end < 0 or end > self.doc.length:
+                raise IndexError(Errors.E1032.format(var="end", obj="Doc", length=self.doc.length, value=end))
+            cdef SpanC* span_c = self.span_c()
+            if span_c.start > end:
+                raise ValueError(Errors.E4005.format(var="end", value=end, op=">=", existing_var="start", existing_value=span_c.start))
+            span_c.end = end
+            if end > 0:
+                span_c.end_char = self.doc.c[end-1].idx + self.doc.c[end-1].lex.length
+            else:
+                span_c.end_char = 0
 
     property start_char:
         def __get__(self):
             return self.span_c().start_char
 
         def __set__(self, int start_char):
-            if start_char < 0:
-                raise IndexError(Errors.E1032.format(var="start_char", forbidden="< 0", value=start_char))
-            self.span_c().start_char = start_char
+            if start_char < 0 or start_char > len(self.doc.text):
+                raise IndexError(Errors.E1032.format(var="start_char", obj="Doc text", length=len(self.doc.text), value=start_char))
+            cdef int start = token_by_start(self.doc.c, self.doc.length, start_char)
+            if start < 0:
+                raise ValueError(Errors.E4006.format(value=start_char, pos="start"))
+            cdef SpanC* span_c = self.span_c()
+            if start_char > span_c.end_char:
+                raise ValueError(Errors.E4005.format(var="start_char", value=start_char, op="<=", existing_var="end_char", existing_value=span_c.end_char))
+            span_c.start_char = start_char
+            span_c.start = start
 
     property end_char:
         def __get__(self):
             return self.span_c().end_char
 
         def __set__(self, int end_char):
-            if end_char < 0:
-                raise IndexError(Errors.E1032.format(var="end_char", forbidden="< 0", value=end_char))
-            self.span_c().end_char = end_char
+            if end_char < 0 or end_char > len(self.doc.text):
+                raise IndexError(Errors.E1032.format(var="end_char", obj="Doc text", length=len(self.doc.text), value=end_char))
+            cdef int end = token_by_end(self.doc.c, self.doc.length, end_char)
+            if end < 0:
+                raise ValueError(Errors.E4006.format(value=end_char, pos="end"))
+            cdef SpanC* span_c = self.span_c()
+            if span_c.start_char > end_char:
+                raise ValueError(Errors.E4005.format(var="end_char", value=end_char, op=">=", existing_var="start_char", existing_value=span_c.start_char))
+            span_c.end_char = end_char
+            span_c.end = end
 
     property label:
         def __get__(self):