diff --git a/spacy/ml/extract_spans.py b/spacy/ml/extract_spans.py index 503873e96..d5e9bc07c 100644 --- a/spacy/ml/extract_spans.py +++ b/spacy/ml/extract_spans.py @@ -32,8 +32,8 @@ def forward( Y = Ragged(X.dataXd[indices], spans.dataXd[:, 1] - spans.dataXd[:, 0]) # type: ignore[arg-type, index] else: Y = Ragged( - ops.xp.zeros((0, 0), dtype=X.dataXd.dtype), - ops.xp.zeros((0,), dtype="i"), + ops.xp.zeros(X.dataXd.shape, dtype=X.dataXd.dtype), + ops.xp.zeros((len(X.lengths),), dtype="i"), ) x_shape = X.dataXd.shape x_lengths = X.lengths