mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-22 01:51:58 +03:00
Enforce that Span.start/end(_char) remain valid and in sync
Allowing span attributes to be writable starting in v3 has made it possible for the internal `Span.start/end/start_char/end_char` to get out-of-sync or have invalid values. This checks that the values are valid and syncs the token and char offsets if any attributes are modified directly. It does not yet handle the case where the underlying doc is modified.
This commit is contained in:
parent
cbc2ae933e
commit
32b3ca3d88
|
@ -923,7 +923,7 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
E1029 = ("Edit tree cannot be applied to form.")
|
||||
E1030 = ("Edit tree identifier out of range.")
|
||||
E1031 = ("Could not find gold transition - see logs above.")
|
||||
E1032 = ("`{var}` should not be {forbidden}, but received {value}.")
|
||||
E1032 = ("Span {var} {value} is out of bounds for {obj} with length {length}.")
|
||||
E1033 = ("Dimension {name} invalid -- only nO, nF, nP")
|
||||
E1034 = ("Node index {i} out of bounds ({length})")
|
||||
E1035 = ("Token index {i} out of bounds ({length})")
|
||||
|
@ -961,6 +961,9 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
E4003 = ("Training examples for distillation must have the exact same tokens in the "
|
||||
"reference and predicted docs.")
|
||||
E4004 = ("Backprop is not supported when is_train is not set.")
|
||||
E4005 = ("Span {var} {value} must be {op} Span {existing_var} "
|
||||
"{existing_value}.")
|
||||
E4006 = ("Span {pos}_char {value} does not correspond to a token {pos}.")
|
||||
|
||||
RENAMED_LANGUAGE_CODES = {"xx": "mul", "is": "isl"}
|
||||
|
||||
|
|
|
@ -707,3 +707,50 @@ def test_span_ent_id(en_tokenizer):
|
|||
doc.ents = [span]
|
||||
assert doc.ents[0].ent_id_ == "ID2"
|
||||
assert doc[1].ent_id_ == "ID2"
|
||||
|
||||
|
||||
def test_span_start_end_sync(en_tokenizer):
|
||||
doc = en_tokenizer("a bc def e fghij kl")
|
||||
# can create and edit span starts/ends
|
||||
span = doc[2:4]
|
||||
span.start_char = 2
|
||||
span.end = 5
|
||||
assert span == doc[span.start : span.end]
|
||||
assert span == doc.char_span(span.start_char, span.end_char)
|
||||
# cannot set completely out of bounds starts/ends
|
||||
with pytest.raises(IndexError):
|
||||
span.start = -1
|
||||
with pytest.raises(IndexError):
|
||||
span.end = -1
|
||||
with pytest.raises(IndexError):
|
||||
span.start_char = len(doc.text) + 1
|
||||
with pytest.raises(IndexError):
|
||||
span.end = len(doc.text) + 1
|
||||
# test all possible char starts/ends
|
||||
span = doc[0 : len(doc)]
|
||||
token_char_starts = [token.idx for token in doc]
|
||||
token_char_ends = [token.idx + len(token.text) for token in doc]
|
||||
for i in range(len(doc.text)):
|
||||
if i not in token_char_starts:
|
||||
with pytest.raises(ValueError):
|
||||
span.start_char = i
|
||||
else:
|
||||
span.start_char = i
|
||||
span = doc[0 : len(doc)]
|
||||
for i in range(len(doc.text)):
|
||||
if i not in token_char_ends:
|
||||
with pytest.raises(ValueError):
|
||||
span.end_char = i
|
||||
else:
|
||||
span.end_char = i
|
||||
# start must be <= end
|
||||
span = doc[1:3]
|
||||
with pytest.raises(ValueError):
|
||||
span.start = 4
|
||||
with pytest.raises(ValueError):
|
||||
span.end = 0
|
||||
span = doc.char_span(2, 8)
|
||||
with pytest.raises(ValueError):
|
||||
span.start_char = 9
|
||||
with pytest.raises(ValueError):
|
||||
span.end_char = 1
|
||||
|
|
|
@ -770,36 +770,61 @@ cdef class Span:
|
|||
return self.span_c().start
|
||||
|
||||
def __set__(self, int start):
|
||||
if start < 0:
|
||||
raise IndexError(Errors.E1032.format(var="start", forbidden="< 0", value=start))
|
||||
self.span_c().start = start
|
||||
if start < 0 or start > self.doc.length:
|
||||
raise IndexError(Errors.E1032.format(var="start", obj="Doc", length=self.doc.length, value=start))
|
||||
cdef SpanC* span_c = self.span_c()
|
||||
if start > span_c.end:
|
||||
raise ValueError(Errors.E4005.format(var="start", value=start, op="<=", existing_var="end", existing_value=span_c.end))
|
||||
span_c.start = start
|
||||
span_c.start_char = self.doc.c[start].idx
|
||||
|
||||
property end:
|
||||
def __get__(self):
|
||||
return self.span_c().end
|
||||
|
||||
def __set__(self, int end):
|
||||
if end < 0:
|
||||
raise IndexError(Errors.E1032.format(var="end", forbidden="< 0", value=end))
|
||||
self.span_c().end = end
|
||||
if end < 0 or end > self.doc.length:
|
||||
raise IndexError(Errors.E1032.format(var="end", obj="Doc", length=self.doc.length, value=end))
|
||||
cdef SpanC* span_c = self.span_c()
|
||||
if span_c.start > end:
|
||||
raise ValueError(Errors.E4005.format(var="end", value=end, op=">=", existing_var="start", existing_value=span_c.start))
|
||||
span_c.end = end
|
||||
if end > 0:
|
||||
span_c.end_char = self.doc.c[end-1].idx + self.doc.c[end-1].lex.length
|
||||
else:
|
||||
span_c.end_char = 0
|
||||
|
||||
property start_char:
|
||||
def __get__(self):
|
||||
return self.span_c().start_char
|
||||
|
||||
def __set__(self, int start_char):
|
||||
if start_char < 0:
|
||||
raise IndexError(Errors.E1032.format(var="start_char", forbidden="< 0", value=start_char))
|
||||
self.span_c().start_char = start_char
|
||||
if start_char < 0 or start_char > len(self.doc.text):
|
||||
raise IndexError(Errors.E1032.format(var="start_char", obj="Doc text", length=len(self.doc.text), value=start_char))
|
||||
cdef int start = token_by_start(self.doc.c, self.doc.length, start_char)
|
||||
if start < 0:
|
||||
raise ValueError(Errors.E4006.format(value=start_char, pos="start"))
|
||||
cdef SpanC* span_c = self.span_c()
|
||||
if start_char > span_c.end_char:
|
||||
raise ValueError(Errors.E4005.format(var="start_char", value=start_char, op="<=", existing_var="end_char", existing_value=span_c.end_char))
|
||||
span_c.start_char = start_char
|
||||
span_c.start = start
|
||||
|
||||
property end_char:
|
||||
def __get__(self):
|
||||
return self.span_c().end_char
|
||||
|
||||
def __set__(self, int end_char):
|
||||
if end_char < 0:
|
||||
raise IndexError(Errors.E1032.format(var="end_char", forbidden="< 0", value=end_char))
|
||||
self.span_c().end_char = end_char
|
||||
if end_char < 0 or end_char > len(self.doc.text):
|
||||
raise IndexError(Errors.E1032.format(var="end_char", obj="Doc text", length=len(self.doc.text), value=end_char))
|
||||
cdef int end = token_by_end(self.doc.c, self.doc.length, end_char)
|
||||
if end < 0:
|
||||
raise ValueError(Errors.E4006.format(value=end_char, pos="end"))
|
||||
cdef SpanC* span_c = self.span_c()
|
||||
if span_c.start_char > end_char:
|
||||
raise ValueError(Errors.E4005.format(var="end_char", value=end_char, op=">=", existing_var="start_char", existing_value=span_c.start_char))
|
||||
span_c.end_char = end_char
|
||||
span_c.end = end
|
||||
|
||||
property label:
|
||||
def __get__(self):
|
||||
|
|
Loading…
Reference in New Issue
Block a user