Enforce that Span.start/end(_char) remain valid and in sync (#12268)

* Enforce that Span.start/end(_char) remain valid and in sync

Allowing span attributes to be writable starting in v3 has made it
possible for the internal `Span.start/end/start_char/end_char` to get
out-of-sync or have invalid values.

This checks that the values are valid and syncs the token and char
offsets if any attributes are modified directly. It does not yet handle
the case where the underlying doc is modified.

* Format
This commit is contained in:
Adriane Boyd 2023-04-06 16:01:59 +02:00 committed by GitHub
parent b734e5314d
commit 5d0f48fe69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 88 additions and 13 deletions

View File

@ -926,7 +926,7 @@ class Errors(metaclass=ErrorsWithCodes):
E1029 = ("Edit tree cannot be applied to form.")
E1030 = ("Edit tree identifier out of range.")
E1031 = ("Could not find gold transition - see logs above.")
E1032 = ("`{var}` should not be {forbidden}, but received {value}.")
E1032 = ("Span {var} {value} is out of bounds for {obj} with length {length}.")
E1033 = ("Dimension {name} invalid -- only nO, nF, nP")
E1034 = ("Node index {i} out of bounds ({length})")
E1035 = ("Token index {i} out of bounds ({length})")
@ -966,6 +966,9 @@ class Errors(metaclass=ErrorsWithCodes):
E4004 = ("Backprop is not supported when is_train is not set.")
E4005 = ("EntityLinker_v1 is not supported in spaCy v4. Update your configuration.")
E4006 = ("Expected `entity_id` to be of type {exp_type}, but is of type {found_type}.")
E4007 = ("Span {var} {value} must be {op} Span {existing_var} "
"{existing_value}.")
E4008 = ("Span {pos}_char {value} does not correspond to a token {pos}.")
RENAMED_LANGUAGE_CODES = {"xx": "mul", "is": "isl"}

View File

@ -707,3 +707,50 @@ def test_span_ent_id(en_tokenizer):
doc.ents = [span]
assert doc.ents[0].ent_id_ == "ID2"
assert doc[1].ent_id_ == "ID2"
def test_span_start_end_sync(en_tokenizer):
doc = en_tokenizer("a bc def e fghij kl")
# can create and edit span starts/ends
span = doc[2:4]
span.start_char = 2
span.end = 5
assert span == doc[span.start : span.end]
assert span == doc.char_span(span.start_char, span.end_char)
# cannot set completely out of bounds starts/ends
with pytest.raises(IndexError):
span.start = -1
with pytest.raises(IndexError):
span.end = -1
with pytest.raises(IndexError):
span.start_char = len(doc.text) + 1
with pytest.raises(IndexError):
span.end = len(doc.text) + 1
# test all possible char starts/ends
span = doc[0 : len(doc)]
token_char_starts = [token.idx for token in doc]
token_char_ends = [token.idx + len(token.text) for token in doc]
for i in range(len(doc.text)):
if i not in token_char_starts:
with pytest.raises(ValueError):
span.start_char = i
else:
span.start_char = i
span = doc[0 : len(doc)]
for i in range(len(doc.text)):
if i not in token_char_ends:
with pytest.raises(ValueError):
span.end_char = i
else:
span.end_char = i
# start must be <= end
span = doc[1:3]
with pytest.raises(ValueError):
span.start = 4
with pytest.raises(ValueError):
span.end = 0
span = doc.char_span(2, 8)
with pytest.raises(ValueError):
span.start_char = 9
with pytest.raises(ValueError):
span.end_char = 1

View File

@ -772,36 +772,61 @@ cdef class Span:
return self.span_c().start
def __set__(self, int start):
if start < 0:
raise IndexError(Errors.E1032.format(var="start", forbidden="< 0", value=start))
self.span_c().start = start
if start < 0 or start > self.doc.length:
raise IndexError(Errors.E1032.format(var="start", obj="Doc", length=self.doc.length, value=start))
cdef SpanC* span_c = self.span_c()
if start > span_c.end:
raise ValueError(Errors.E4007.format(var="start", value=start, op="<=", existing_var="end", existing_value=span_c.end))
span_c.start = start
span_c.start_char = self.doc.c[start].idx
property end:
def __get__(self):
return self.span_c().end
def __set__(self, int end):
if end < 0:
raise IndexError(Errors.E1032.format(var="end", forbidden="< 0", value=end))
self.span_c().end = end
if end < 0 or end > self.doc.length:
raise IndexError(Errors.E1032.format(var="end", obj="Doc", length=self.doc.length, value=end))
cdef SpanC* span_c = self.span_c()
if span_c.start > end:
raise ValueError(Errors.E4007.format(var="end", value=end, op=">=", existing_var="start", existing_value=span_c.start))
span_c.end = end
if end > 0:
span_c.end_char = self.doc.c[end-1].idx + self.doc.c[end-1].lex.length
else:
span_c.end_char = 0
property start_char:
def __get__(self):
return self.span_c().start_char
def __set__(self, int start_char):
if start_char < 0:
raise IndexError(Errors.E1032.format(var="start_char", forbidden="< 0", value=start_char))
self.span_c().start_char = start_char
if start_char < 0 or start_char > len(self.doc.text):
raise IndexError(Errors.E1032.format(var="start_char", obj="Doc text", length=len(self.doc.text), value=start_char))
cdef int start = token_by_start(self.doc.c, self.doc.length, start_char)
if start < 0:
raise ValueError(Errors.E4008.format(value=start_char, pos="start"))
cdef SpanC* span_c = self.span_c()
if start_char > span_c.end_char:
raise ValueError(Errors.E4007.format(var="start_char", value=start_char, op="<=", existing_var="end_char", existing_value=span_c.end_char))
span_c.start_char = start_char
span_c.start = start
property end_char:
def __get__(self):
return self.span_c().end_char
def __set__(self, int end_char):
if end_char < 0:
raise IndexError(Errors.E1032.format(var="end_char", forbidden="< 0", value=end_char))
self.span_c().end_char = end_char
if end_char < 0 or end_char > len(self.doc.text):
raise IndexError(Errors.E1032.format(var="end_char", obj="Doc text", length=len(self.doc.text), value=end_char))
cdef int end = token_by_end(self.doc.c, self.doc.length, end_char)
if end < 0:
raise ValueError(Errors.E4008.format(value=end_char, pos="end"))
cdef SpanC* span_c = self.span_c()
if span_c.start_char > end_char:
raise ValueError(Errors.E4007.format(var="end_char", value=end_char, op=">=", existing_var="start_char", existing_value=span_c.start_char))
span_c.end_char = end_char
span_c.end = end
property label:
def __get__(self):