Undo changes to extract_spans

This commit is contained in:
Adriane Boyd 2022-11-29 13:17:28 +01:00
parent df73823aae
commit 4ecccf8460

View File

@ -32,8 +32,8 @@ def forward(
Y = Ragged(X.dataXd[indices], spans.dataXd[:, 1] - spans.dataXd[:, 0]) # type: ignore[arg-type, index] Y = Ragged(X.dataXd[indices], spans.dataXd[:, 1] - spans.dataXd[:, 0]) # type: ignore[arg-type, index]
else: else:
Y = Ragged( Y = Ragged(
ops.xp.zeros((0, 0), dtype=X.dataXd.dtype), ops.xp.zeros(X.dataXd.shape, dtype=X.dataXd.dtype),
ops.xp.zeros((0,), dtype="i"), ops.xp.zeros((len(X.lengths),), dtype="i"),
) )
x_shape = X.dataXd.shape x_shape = X.dataXd.shape
x_lengths = X.lengths x_lengths = X.lengths