Fix realloc in retokenizer.split() (#4606)

Always realloc to a size larger than `doc.max_length` in
`retokenizer.split()` (or cymem will throw errors).
This commit is contained in:
adrianeboyd 2019-11-11 16:26:46 +01:00 committed by Matthew Honnibal
parent f415e9b7d1
commit 91f89f9693
2 changed files with 16 additions and 1 deletions

View File

@ -183,3 +183,18 @@ def test_doc_retokenizer_split_lex_attrs(en_vocab):
retokenizer.split(doc[0], ["Los", "Angeles"], heads, attrs=attrs) retokenizer.split(doc[0], ["Los", "Angeles"], heads, attrs=attrs)
assert doc[0].is_stop assert doc[0].is_stop
assert not doc[1].is_stop assert not doc[1].is_stop
def test_doc_retokenizer_realloc(en_vocab):
"""#4604: realloc correctly when new tokens outnumber original tokens"""
text = "Hyperglycemic adverse events following antipsychotic drug administration in the"
doc = Doc(en_vocab, words=text.split()[:-1])
with doc.retokenize() as retokenizer:
token = doc[0]
heads = [(token, 0)] * len(token)
retokenizer.split(doc[token.i], list(token.text), heads=heads)
doc = Doc(en_vocab, words=text.split())
with doc.retokenize() as retokenizer:
token = doc[0]
heads = [(token, 0)] * len(token)
retokenizer.split(doc[token.i], list(token.text), heads=heads)

View File

@ -329,7 +329,7 @@ def _split(Doc doc, int token_index, orths, heads, attrs):
doc.c[i].head += offset doc.c[i].head += offset
# Double doc.c max_length if necessary (until big enough for all new tokens) # Double doc.c max_length if necessary (until big enough for all new tokens)
while doc.length + nb_subtokens - 1 >= doc.max_length: while doc.length + nb_subtokens - 1 >= doc.max_length:
doc._realloc(doc.length * 2) doc._realloc(doc.max_length * 2)
# Move tokens after the split to create space for the new tokens # Move tokens after the split to create space for the new tokens
doc.length = len(doc) + nb_subtokens -1 doc.length = len(doc) + nb_subtokens -1
to_process_tensor = (doc.tensor is not None and doc.tensor.size != 0) to_process_tensor = (doc.tensor is not None and doc.tensor.size != 0)