Fix making span_group (#8975)

* fix _make_span_group

* fix imports
This commit is contained in:
Sofie Van Landeghem 2021-08-17 10:36:34 +02:00 committed by GitHub
parent 593a22cf2d
commit 0a6b68848f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 58 additions and 2 deletions

View File

@ -411,7 +411,10 @@ class SpanCategorizer(TrainablePipe):
keeps = scores >= threshold
ranked = (scores * -1).argsort()
keeps[ranked[:, max_positive:]] = False
if max_positive is not None:
filter = ranked[:, max_positive:]
for i, row in enumerate(filter):
keeps[i, row] = False
spans.attrs["scores"] = scores[keeps].flatten()
indices = self.model.ops.to_numpy(indices)

View File

@ -1,5 +1,6 @@
import pytest
from numpy.testing import assert_equal, assert_array_equal
import numpy
from numpy.testing import assert_equal, assert_array_equal, assert_almost_equal
from thinc.api import get_current_ops
from spacy.language import Language
from spacy.training import Example
@ -71,6 +72,55 @@ def test_explicit_labels():
assert spancat.labels == ("PERSON", "LOC")
@pytest.mark.parametrize(
"max_positive,nr_results", [(None, 4), (1, 2), (2, 3), (3, 4), (4, 4)]
)
def test_make_spangroup(max_positive, nr_results):
fix_random_seed(0)
nlp = Language()
spancat = nlp.add_pipe(
"spancat",
config={"spans_key": SPAN_KEY, "threshold": 0.5, "max_positive": max_positive},
)
doc = nlp.make_doc("Greater London")
ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2])
indices = ngram_suggester([doc])[0].dataXd
assert_array_equal(indices, numpy.asarray([[0, 1], [1, 2], [0, 2]]))
labels = ["Thing", "City", "Person", "GreatCity"]
scores = numpy.asarray(
[[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"
)
spangroup = spancat._make_span_group(doc, indices, scores, labels)
assert len(spangroup) == nr_results
# first span is always the second token "London"
assert spangroup[0].text == "London"
assert spangroup[0].label_ == "City"
assert_almost_equal(0.6, spangroup.attrs["scores"][0], 5)
# second span depends on the number of positives that were allowed
assert spangroup[1].text == "Greater London"
if max_positive == 1:
assert spangroup[1].label_ == "GreatCity"
assert_almost_equal(0.9, spangroup.attrs["scores"][1], 5)
else:
assert spangroup[1].label_ == "Thing"
assert_almost_equal(0.8, spangroup.attrs["scores"][1], 5)
if nr_results > 2:
assert spangroup[2].text == "Greater London"
if max_positive == 2:
assert spangroup[2].label_ == "GreatCity"
assert_almost_equal(0.9, spangroup.attrs["scores"][2], 5)
else:
assert spangroup[2].label_ == "City"
assert_almost_equal(0.7, spangroup.attrs["scores"][2], 5)
assert spangroup[-1].text == "Greater London"
assert spangroup[-1].label_ == "GreatCity"
assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5)
def test_simple_train():
fix_random_seed(0)
nlp = Language()
@ -90,6 +140,9 @@ def test_simple_train():
scores = nlp.evaluate(get_examples())
assert f"spans_{SPAN_KEY}_f" in scores
assert scores[f"spans_{SPAN_KEY}_f"] == 1.0
# also test that the spancat works for just a single entity in a sentence
doc = nlp("London")
assert len(doc.spans[spancat.key]) == 1
def test_ngram_suggester(en_tokenizer):