Enforce that Span.start/end(_char) remain valid and in sync (#12268)

* Enforce that Span.start/end(_char) remain valid and in sync

Allowing span attributes to be writable starting in v3 has made it
possible for the internal `Span.start/end/start_char/end_char` to get
out-of-sync or have invalid values.

This checks that the values are valid and syncs the token and char
offsets if any attributes are modified directly. It does not yet handle
the case where the underlying doc is modified.

* Format
This commit is contained in:
Adriane Boyd 2023-04-06 16:01:59 +02:00 committed by GitHub
parent b734e5314d
commit 5d0f48fe69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 88 additions and 13 deletions

View File

@ -926,7 +926,7 @@ class Errors(metaclass=ErrorsWithCodes):
E1029 = ("Edit tree cannot be applied to form.") E1029 = ("Edit tree cannot be applied to form.")
E1030 = ("Edit tree identifier out of range.") E1030 = ("Edit tree identifier out of range.")
E1031 = ("Could not find gold transition - see logs above.") E1031 = ("Could not find gold transition - see logs above.")
E1032 = ("`{var}` should not be {forbidden}, but received {value}.") E1032 = ("Span {var} {value} is out of bounds for {obj} with length {length}.")
E1033 = ("Dimension {name} invalid -- only nO, nF, nP") E1033 = ("Dimension {name} invalid -- only nO, nF, nP")
E1034 = ("Node index {i} out of bounds ({length})") E1034 = ("Node index {i} out of bounds ({length})")
E1035 = ("Token index {i} out of bounds ({length})") E1035 = ("Token index {i} out of bounds ({length})")
@ -966,6 +966,9 @@ class Errors(metaclass=ErrorsWithCodes):
E4004 = ("Backprop is not supported when is_train is not set.") E4004 = ("Backprop is not supported when is_train is not set.")
E4005 = ("EntityLinker_v1 is not supported in spaCy v4. Update your configuration.") E4005 = ("EntityLinker_v1 is not supported in spaCy v4. Update your configuration.")
E4006 = ("Expected `entity_id` to be of type {exp_type}, but is of type {found_type}.") E4006 = ("Expected `entity_id` to be of type {exp_type}, but is of type {found_type}.")
E4007 = ("Span {var} {value} must be {op} Span {existing_var} "
"{existing_value}.")
E4008 = ("Span {pos}_char {value} does not correspond to a token {pos}.")
RENAMED_LANGUAGE_CODES = {"xx": "mul", "is": "isl"} RENAMED_LANGUAGE_CODES = {"xx": "mul", "is": "isl"}

View File

@ -707,3 +707,50 @@ def test_span_ent_id(en_tokenizer):
doc.ents = [span] doc.ents = [span]
assert doc.ents[0].ent_id_ == "ID2" assert doc.ents[0].ent_id_ == "ID2"
assert doc[1].ent_id_ == "ID2" assert doc[1].ent_id_ == "ID2"
def test_span_start_end_sync(en_tokenizer):
doc = en_tokenizer("a bc def e fghij kl")
# can create and edit span starts/ends
span = doc[2:4]
span.start_char = 2
span.end = 5
assert span == doc[span.start : span.end]
assert span == doc.char_span(span.start_char, span.end_char)
# cannot set completely out of bounds starts/ends
with pytest.raises(IndexError):
span.start = -1
with pytest.raises(IndexError):
span.end = -1
with pytest.raises(IndexError):
span.start_char = len(doc.text) + 1
with pytest.raises(IndexError):
span.end = len(doc.text) + 1
# test all possible char starts/ends
span = doc[0 : len(doc)]
token_char_starts = [token.idx for token in doc]
token_char_ends = [token.idx + len(token.text) for token in doc]
for i in range(len(doc.text)):
if i not in token_char_starts:
with pytest.raises(ValueError):
span.start_char = i
else:
span.start_char = i
span = doc[0 : len(doc)]
for i in range(len(doc.text)):
if i not in token_char_ends:
with pytest.raises(ValueError):
span.end_char = i
else:
span.end_char = i
# start must be <= end
span = doc[1:3]
with pytest.raises(ValueError):
span.start = 4
with pytest.raises(ValueError):
span.end = 0
span = doc.char_span(2, 8)
with pytest.raises(ValueError):
span.start_char = 9
with pytest.raises(ValueError):
span.end_char = 1

View File

@ -772,36 +772,61 @@ cdef class Span:
return self.span_c().start return self.span_c().start
def __set__(self, int start): def __set__(self, int start):
if start < 0: if start < 0 or start > self.doc.length:
raise IndexError(Errors.E1032.format(var="start", forbidden="< 0", value=start)) raise IndexError(Errors.E1032.format(var="start", obj="Doc", length=self.doc.length, value=start))
self.span_c().start = start cdef SpanC* span_c = self.span_c()
if start > span_c.end:
raise ValueError(Errors.E4007.format(var="start", value=start, op="<=", existing_var="end", existing_value=span_c.end))
span_c.start = start
span_c.start_char = self.doc.c[start].idx
property end: property end:
def __get__(self): def __get__(self):
return self.span_c().end return self.span_c().end
def __set__(self, int end): def __set__(self, int end):
if end < 0: if end < 0 or end > self.doc.length:
raise IndexError(Errors.E1032.format(var="end", forbidden="< 0", value=end)) raise IndexError(Errors.E1032.format(var="end", obj="Doc", length=self.doc.length, value=end))
self.span_c().end = end cdef SpanC* span_c = self.span_c()
if span_c.start > end:
raise ValueError(Errors.E4007.format(var="end", value=end, op=">=", existing_var="start", existing_value=span_c.start))
span_c.end = end
if end > 0:
span_c.end_char = self.doc.c[end-1].idx + self.doc.c[end-1].lex.length
else:
span_c.end_char = 0
property start_char: property start_char:
def __get__(self): def __get__(self):
return self.span_c().start_char return self.span_c().start_char
def __set__(self, int start_char): def __set__(self, int start_char):
if start_char < 0: if start_char < 0 or start_char > len(self.doc.text):
raise IndexError(Errors.E1032.format(var="start_char", forbidden="< 0", value=start_char)) raise IndexError(Errors.E1032.format(var="start_char", obj="Doc text", length=len(self.doc.text), value=start_char))
self.span_c().start_char = start_char cdef int start = token_by_start(self.doc.c, self.doc.length, start_char)
if start < 0:
raise ValueError(Errors.E4008.format(value=start_char, pos="start"))
cdef SpanC* span_c = self.span_c()
if start_char > span_c.end_char:
raise ValueError(Errors.E4007.format(var="start_char", value=start_char, op="<=", existing_var="end_char", existing_value=span_c.end_char))
span_c.start_char = start_char
span_c.start = start
property end_char: property end_char:
def __get__(self): def __get__(self):
return self.span_c().end_char return self.span_c().end_char
def __set__(self, int end_char): def __set__(self, int end_char):
if end_char < 0: if end_char < 0 or end_char > len(self.doc.text):
raise IndexError(Errors.E1032.format(var="end_char", forbidden="< 0", value=end_char)) raise IndexError(Errors.E1032.format(var="end_char", obj="Doc text", length=len(self.doc.text), value=end_char))
self.span_c().end_char = end_char cdef int end = token_by_end(self.doc.c, self.doc.length, end_char)
if end < 0:
raise ValueError(Errors.E4008.format(value=end_char, pos="end"))
cdef SpanC* span_c = self.span_c()
if span_c.start_char > end_char:
raise ValueError(Errors.E4007.format(var="end_char", value=end_char, op=">=", existing_var="start_char", existing_value=span_c.start_char))
span_c.end_char = end_char
span_c.end = end
property label: property label:
def __get__(self): def __get__(self):