Add Span index boundary checks (#5861)

* Add Span index boundary checks

* Return Span-specific IndexError in all cases

* Simplify and fix if/else
This commit is contained in:
Adriane Boyd 2020-08-04 13:35:25 +02:00 committed by GitHub
parent cd59979ab4
commit b841248589
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 2 deletions

View File

@ -588,6 +588,7 @@ class Errors(object):
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
E200 = ("Specifying a base model with a pretrained component '{component}' "
"can not be combined with adding a pretrained Tok2Vec layer.")
E201 = ("Span index out of range.")
@add_codes

View File

@ -287,3 +287,15 @@ def test_span_eq_hash(doc, doc_not_parsed):
assert hash(doc[0:2]) == hash(doc[0:2])
assert hash(doc[0:2]) != hash(doc[1:3])
assert hash(doc[0:2]) != hash(doc_not_parsed[0:2])
def test_span_boundaries(doc):
    """Check that Span indexing stays inside the span's own token range.

    Every valid offset, counted from either end of the span, must resolve
    to the corresponding token of the parent doc. The first offset past
    each edge (and one further out) must raise IndexError.
    """
    start = 1
    end = 5
    span = doc[start:end]
    length = end - start
    for i in range(start, end):
        assert span[i - start] == doc[i]
        # Same token addressed with a negative offset from the span's end.
        assert span[i - end] == doc[i]
    # -length - 1 and length are the exact boundaries; length + 1 is the
    # value the test originally probed on the positive side.
    for bad_index in (-length - 1, length, length + 1):
        with pytest.raises(IndexError):
            _ = span[bad_index]

View File

@ -181,9 +181,13 @@ cdef class Span:
return Span(self.doc, start + self.start, end + self.start)
else:
if i < 0:
return self.doc[self.end + i]
token_i = self.end + i
else:
return self.doc[self.start + i]
token_i = self.start + i
if self.start <= token_i < self.end:
return self.doc[token_i]
else:
raise IndexError(Errors.E201)
def __iter__(self):
"""Iterate over `Token` objects.