Fix tensor retokenization for non-numpy ops (#7527)

Implement manual `append` and `delete` for non-numpy ops.
This commit is contained in:
Adriane Boyd 2021-03-29 13:34:48 +02:00 committed by GitHub
parent 139f655f34
commit 3ae8661085
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -295,7 +295,19 @@ def _resize_tensor(tensor, ranges):
for i in range(start, end-1):
delete.append(i)
xp = get_array_module(tensor)
if xp is numpy:
return xp.delete(tensor, delete, axis=0)
else:
offset = 0
copy_start = 0
resized_shape = (tensor.shape[0] - len(delete), tensor.shape[1])
for start, end in ranges:
if copy_start > 0:
tensor[copy_start - offset:start - offset] = tensor[copy_start: start]
offset += end - start - 1
copy_start = end - 1
tensor[copy_start - offset:resized_shape[0]] = tensor[copy_start:]
return xp.asarray(tensor[:resized_shape[0]])
def _split(Doc doc, int token_index, orths, heads, attrs):
@ -332,7 +344,13 @@ def _split(Doc doc, int token_index, orths, heads, attrs):
to_process_tensor = (doc.tensor is not None and doc.tensor.size != 0)
if to_process_tensor:
xp = get_array_module(doc.tensor)
if xp is numpy:
doc.tensor = xp.append(doc.tensor, xp.zeros((nb_subtokens,doc.tensor.shape[1]), dtype="float32"), axis=0)
else:
shape = (doc.tensor.shape[0] + nb_subtokens, doc.tensor.shape[1])
resized_array = xp.zeros(shape, dtype="float32")
resized_array[:doc.tensor.shape[0]] = doc.tensor[:doc.tensor.shape[0]]
doc.tensor = resized_array
for token_to_move in range(orig_length - 1, token_index, -1):
doc.c[token_to_move + nb_subtokens - 1] = doc.c[token_to_move]
if to_process_tensor:
@ -349,7 +367,7 @@ def _split(Doc doc, int token_index, orths, heads, attrs):
token.norm = 0 # reset norm
if to_process_tensor:
# setting the tensors of the split tokens to array of zeros
doc.tensor[token_index + i] = xp.zeros((1,doc.tensor.shape[1]), dtype="float32")
doc.tensor[token_index + i:token_index + i + 1] = xp.zeros((1,doc.tensor.shape[1]), dtype="float32")
# Update the character offset of the subtokens
if i != 0:
token.idx = orig_token.idx + idx_offset