Improve token head verification (#5079)

* Improve token head verification

Improve the verification for valid token heads when heads are set:

* in `Token.head`: heads come from the same document
* in `Doc.from_array()`: head indices are within the bounds of the
document

* Improve error message
This commit is contained in:
adrianeboyd 2020-03-03 21:44:51 +01:00 committed by GitHub
parent 8c20dae6f7
commit 9be90dbca3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 51 additions and 1 deletions

View File

@ -545,6 +545,13 @@ class Errors(object):
"make sure the gold EL data refers to valid results of the " "make sure the gold EL data refers to valid results of the "
"named entity recognizer in the `nlp` pipeline.") "named entity recognizer in the `nlp` pipeline.")
E189 = ("Each argument to `get_doc` should be of equal length.") E189 = ("Each argument to `get_doc` should be of equal length.")
E190 = ("Token head out of range in `Doc.from_array()` for token index "
"'{index}' with value '{value}' (equivalent to relative head "
"index: '{rel_head_index}'). The head indices should be relative "
"to the current token index rather than absolute indices in the "
"array.")
E191 = ("Invalid head: the head token must be from the same doc as the "
"token itself.")
@add_codes @add_codes

View File

@ -77,3 +77,30 @@ def test_doc_array_idx(en_vocab):
assert offsets[0] == 0 assert offsets[0] == 0
assert offsets[1] == 3 assert offsets[1] == 3
assert offsets[2] == 11 assert offsets[2] == 11
def test_doc_from_array_heads_in_bounds(en_vocab):
"""Test that Doc.from_array doesn't set heads that are out of bounds."""
words = ["This", "is", "a", "sentence", "."]
doc = Doc(en_vocab, words=words)
for token in doc:
token.head = doc[0]
# correct
arr = doc.to_array(["HEAD"])
doc_from_array = Doc(en_vocab, words=words)
doc_from_array.from_array(["HEAD"], arr)
# head before start
arr = doc.to_array(["HEAD"])
arr[0] = -1
doc_from_array = Doc(en_vocab, words=words)
with pytest.raises(ValueError):
doc_from_array.from_array(["HEAD"], arr)
# head after end
arr = doc.to_array(["HEAD"])
arr[0] = 5
doc_from_array = Doc(en_vocab, words=words)
with pytest.raises(ValueError):
doc_from_array.from_array(["HEAD"], arr)

View File

@ -167,6 +167,11 @@ def test_doc_token_api_head_setter(en_tokenizer):
assert doc[4].left_edge.i == 0 assert doc[4].left_edge.i == 0
assert doc[2].left_edge.i == 0 assert doc[2].left_edge.i == 0
# head token must be from the same document
doc2 = get_doc(tokens.vocab, words=[t.text for t in tokens], heads=heads)
with pytest.raises(ValueError):
doc[0].head = doc2[0]
def test_is_sent_start(en_tokenizer): def test_is_sent_start(en_tokenizer):
doc = en_tokenizer("This is a sentence. This is another.") doc = en_tokenizer("This is a sentence. This is another.")

View File

@ -790,7 +790,7 @@ cdef class Doc:
if SENT_START in attrs and HEAD in attrs: if SENT_START in attrs and HEAD in attrs:
raise ValueError(Errors.E032) raise ValueError(Errors.E032)
cdef int i, col cdef int i, col, abs_head_index
cdef attr_id_t attr_id cdef attr_id_t attr_id
cdef TokenC* tokens = self.c cdef TokenC* tokens = self.c
cdef int length = len(array) cdef int length = len(array)
@ -804,6 +804,14 @@ cdef class Doc:
attr_ids[i] = attr_id attr_ids[i] = attr_id
if len(array.shape) == 1: if len(array.shape) == 1:
array = array.reshape((array.size, 1)) array = array.reshape((array.size, 1))
# Check that all heads are within the document bounds
if HEAD in attrs:
col = attrs.index(HEAD)
for i in range(length):
# cast index to signed int
abs_head_index = numpy.int32(array[i, col]) + i
if abs_head_index < 0 or abs_head_index >= length:
raise ValueError(Errors.E190.format(index=i, value=array[i, col], rel_head_index=numpy.int32(array[i, col])))
# Do TAG first. This lets subsequent loop override stuff like POS, LEMMA # Do TAG first. This lets subsequent loop override stuff like POS, LEMMA
if TAG in attrs: if TAG in attrs:
col = attrs.index(TAG) col = attrs.index(TAG)

View File

@ -623,6 +623,9 @@ cdef class Token:
# This function sets the head of self to new_head and updates the # This function sets the head of self to new_head and updates the
# counters for left/right dependents and left/right corner for the # counters for left/right dependents and left/right corner for the
# new and the old head # new and the old head
# Check that token is from the same document
if self.doc != new_head.doc:
raise ValueError(Errors.E191)
# Do nothing if old head is new head # Do nothing if old head is new head
if self.i + self.c.head == new_head.i: if self.i + self.c.head == new_head.i:
return return