Fix spancat for zero suggestions

This commit is contained in:
Adriane Boyd 2022-11-29 11:20:43 +01:00
parent 5803cb87a4
commit df73823aae
3 changed files with 40 additions and 13 deletions

View File

@@ -32,8 +32,8 @@ def forward(
Y = Ragged(X.dataXd[indices], spans.dataXd[:, 1] - spans.dataXd[:, 0]) # type: ignore[arg-type, index]
else:
Y = Ragged(
ops.xp.zeros(X.dataXd.shape, dtype=X.dataXd.dtype),
ops.xp.zeros((len(X.lengths),), dtype="i"),
ops.xp.zeros((0, 0), dtype=X.dataXd.dtype),
ops.xp.zeros((0,), dtype="i"),
)
x_shape = X.dataXd.shape
x_lengths = X.lengths

View File

@@ -272,7 +272,10 @@ class SpanCategorizer(TrainablePipe):
DOCS: https://spacy.io/api/spancategorizer#predict
"""
indices = self.suggester(docs, ops=self.model.ops)
scores = self.model.predict((docs, indices)) # type: ignore
if sum(indices.lengths) == 0:
scores = self.model.ops.alloc2f(0, 0)
else:
scores = self.model.predict((docs, indices)) # type: ignore
return indices, scores
def set_candidates(

View File

@@ -372,24 +372,39 @@ def test_overfitting_IO_overlapping():
def test_zero_suggestions():
# Test with a suggester that returns 0 suggestions
# Test with a suggester that can return 0 suggestions
@registry.misc("test_zero_suggester")
def make_zero_suggester():
def zero_suggester(docs, *, ops=None):
@registry.misc("test_mixed_zero_suggester")
def make_mixed_zero_suggester():
def mixed_zero_suggester(docs, *, ops=None):
if ops is None:
ops = get_current_ops()
return Ragged(
ops.xp.zeros((0, 0), dtype="i"), ops.xp.zeros((len(docs),), dtype="i")
)
spans = []
lengths = []
for doc in docs:
if len(doc) > 0 and len(doc) % 2 == 0:
spans.append((0, 1))
lengths.append(1)
else:
lengths.append(0)
spans = ops.asarray2i(spans)
lengths_array = ops.asarray1i(lengths)
if len(spans) > 0:
output = Ragged(ops.xp.vstack(spans), lengths_array)
else:
output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array)
return output
return zero_suggester
return mixed_zero_suggester
fix_random_seed(0)
nlp = English()
spancat = nlp.add_pipe(
"spancat",
config={"suggester": {"@misc": "test_zero_suggester"}, "spans_key": SPAN_KEY},
config={
"suggester": {"@misc": "test_mixed_zero_suggester"},
"spans_key": SPAN_KEY,
},
)
train_examples = make_examples(nlp)
optimizer = nlp.initialize(get_examples=lambda: train_examples)
@@ -397,7 +412,16 @@ def test_zero_suggestions():
assert set(spancat.labels) == {"LOC", "PERSON"}
nlp.update(train_examples, sgd=optimizer)
nlp("zero")
# empty doc
nlp("")
# single doc with zero suggestions
nlp("one")
# single doc with one suggestion
nlp("two two")
# batch with mixed zero/one suggestions
list(nlp.pipe(["one", "two two", "three three three", "", "four four four four"]))
# batch with no suggestions
list(nlp.pipe(["", "one", "three three three"]))
def test_set_candidates():