diff --git a/spacy/tokens/_retokenize.pyx b/spacy/tokens/_retokenize.pyx
index 5c7523667..43e6d4aa7 100644
--- a/spacy/tokens/_retokenize.pyx
+++ b/spacy/tokens/_retokenize.pyx
@@ -295,7 +295,19 @@ def _resize_tensor(tensor, ranges):
         for i in range(start, end-1):
             delete.append(i)
     xp = get_array_module(tensor)
-    return xp.delete(tensor, delete, axis=0)
+    if xp is numpy:
+        return xp.delete(tensor, delete, axis=0)
+    else:
+        offset = 0
+        copy_start = 0
+        resized_shape = (tensor.shape[0] - len(delete), tensor.shape[1])
+        for start, end in ranges:
+            if copy_start > 0:
+                tensor[copy_start - offset:start - offset] = tensor[copy_start: start]
+            offset += end - start - 1
+            copy_start = end - 1
+        tensor[copy_start - offset:resized_shape[0]] = tensor[copy_start:]
+        return xp.asarray(tensor[:resized_shape[0]])


 def _split(Doc doc, int token_index, orths, heads, attrs):
@@ -332,7 +344,13 @@ def _split(Doc doc, int token_index, orths, heads, attrs):
     to_process_tensor = (doc.tensor is not None and doc.tensor.size != 0)
     if to_process_tensor:
         xp = get_array_module(doc.tensor)
-        doc.tensor = xp.append(doc.tensor, xp.zeros((nb_subtokens,doc.tensor.shape[1]), dtype="float32"), axis=0)
+        if xp is numpy:
+            doc.tensor = xp.append(doc.tensor, xp.zeros((nb_subtokens,doc.tensor.shape[1]), dtype="float32"), axis=0)
+        else:
+            shape = (doc.tensor.shape[0] + nb_subtokens, doc.tensor.shape[1])
+            resized_array = xp.zeros(shape, dtype="float32")
+            resized_array[:doc.tensor.shape[0]] = doc.tensor[:doc.tensor.shape[0]]
+            doc.tensor = resized_array
     for token_to_move in range(orig_length - 1, token_index, -1):
         doc.c[token_to_move + nb_subtokens - 1] = doc.c[token_to_move]
         if to_process_tensor:
@@ -349,7 +367,7 @@ def _split(Doc doc, int token_index, orths, heads, attrs):
             token.norm = 0  # reset norm
             if to_process_tensor:
                 # setting the tensors of the split tokens to array of zeros
-                doc.tensor[token_index + i] = xp.zeros((1,doc.tensor.shape[1]), dtype="float32")
+                doc.tensor[token_index + i:token_index + i + 1] = xp.zeros((1,doc.tensor.shape[1]), dtype="float32")
             # Update the character offset of the subtokens
             if i != 0:
                 token.idx = orig_token.idx + idx_offset