Fix spancat for empty docs and zero suggestions (#9654)

* Fix spancat for empty docs and zero suggestions

* Use ops.xp.zeros in test
This commit is contained in:
Adriane Boyd 2021-11-15 12:40:55 +01:00 committed by GitHub
parent 67d8c8a081
commit c9baf9d196
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 4 deletions

View File

@ -28,7 +28,13 @@ def forward(
X, spans = source_spans
assert spans.dataXd.ndim == 2
indices = _get_span_indices(ops, spans, X.lengths)
Y = Ragged(X.dataXd[indices], spans.dataXd[:, 1] - spans.dataXd[:, 0]) # type: ignore[arg-type, index]
if len(indices) > 0:
Y = Ragged(X.dataXd[indices], spans.dataXd[:, 1] - spans.dataXd[:, 0]) # type: ignore[arg-type, index]
else:
Y = Ragged(
ops.xp.zeros(X.dataXd.shape, dtype=X.dataXd.dtype),
ops.xp.zeros((len(X.lengths),), dtype="i"),
)
x_shape = X.dataXd.shape
x_lengths = X.lengths
@ -53,7 +59,7 @@ def _get_span_indices(ops, spans: Ragged, lengths: Ints1d) -> Ints1d:
for j in range(spans_i.shape[0]):
indices.append(ops.xp.arange(spans_i[j, 0], spans_i[j, 1])) # type: ignore[call-overload, index]
offset += length
return ops.flatten(indices)
return ops.flatten(indices, dtype="i", ndim_if_empty=1)
def _ensure_cpu(spans: Ragged, lengths: Ints1d) -> Tuple[Ragged, Ints1d]:

View File

@ -78,7 +78,7 @@ def build_ngram_suggester(sizes: List[int]) -> Suggester:
if len(spans) > 0:
output = Ragged(ops.xp.vstack(spans), lengths_array)
else:
output = Ragged(ops.xp.zeros((0, 0)), lengths_array)
output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array)
assert output.dataXd.ndim == 2
return output

View File

@ -1,7 +1,7 @@
import pytest
import numpy
from numpy.testing import assert_array_equal, assert_almost_equal
from thinc.api import get_current_ops
from thinc.api import get_current_ops, Ragged
from spacy import util
from spacy.lang.en import English
@ -29,6 +29,7 @@ TRAIN_DATA_OVERLAPPING = [
"I like London and Berlin",
{"spans": {SPAN_KEY: [(7, 13, "LOC"), (18, 24, "LOC"), (7, 24, "DOUBLE_LOC")]}},
),
("", {"spans": {SPAN_KEY: []}}),
]
@ -365,3 +366,31 @@ def test_overfitting_IO_overlapping():
"London and Berlin",
}
assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"}
def test_zero_suggestions():
# Test with a suggester that returns 0 suggestions
@registry.misc("test_zero_suggester")
def make_zero_suggester():
def zero_suggester(docs, *, ops=None):
if ops is None:
ops = get_current_ops()
return Ragged(
ops.xp.zeros((0, 0), dtype="i"), ops.xp.zeros((len(docs),), dtype="i")
)
return zero_suggester
fix_random_seed(0)
nlp = English()
spancat = nlp.add_pipe(
"spancat",
config={"suggester": {"@misc": "test_zero_suggester"}, "spans_key": SPAN_KEY},
)
train_examples = make_examples(nlp)
optimizer = nlp.initialize(get_examples=lambda: train_examples)
assert spancat.model.get_dim("nO") == 2
assert set(spancat.labels) == {"LOC", "PERSON"}
nlp.update(train_examples, sgd=optimizer)