mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-05 13:43:24 +03:00
single label make_spangroup test
This commit is contained in:
parent
6fc25f64dd
commit
ec941a128d
|
@ -108,13 +108,13 @@ def test_doc_gc():
|
|||
# XXX This fails with length 0 sometimes
|
||||
assert len(spangroup) > 0
|
||||
with pytest.raises(RuntimeError):
|
||||
span = spangroup[0]
|
||||
spangroup[0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"max_positive,nr_results", [(None, 4), (1, 2), (2, 3), (3, 4), (4, 4)]
|
||||
)
|
||||
def test_make_spangroup(max_positive, nr_results):
|
||||
def test_make_spangroup_multilabel(max_positive, nr_results):
|
||||
fix_random_seed(0)
|
||||
nlp = Language()
|
||||
spancat = nlp.add_pipe(
|
||||
|
@ -160,6 +160,55 @@ def test_make_spangroup(max_positive, nr_results):
|
|||
assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "threshold,allow_overlap,nr_results",
    [(0.05, True, 3), (0.05, False, 1), (0.5, True, 2), (0.5, False, 1)],
)
def test_make_spangroup_singlelabel(threshold, allow_overlap, nr_results):
    """Check the span group built by `_make_span_group_singlelabel`.

    The test is parametrized over the spancat `threshold`, whether
    overlapping spans are allowed, and the number of spans expected in the
    resulting group. For every combination it verifies the group length and
    the text and label of each span in it.
    """
    fix_random_seed(0)
    nlp = Language()
    spancat = nlp.add_pipe(
        "spancat",
        config={
            "spans_key": SPAN_KEY,
            "threshold": threshold,
            "max_positive": 1,
        },
    )
    doc = nlp.make_doc("Greater London")
    ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2])
    indices = ngram_suggester([doc])[0].dataXd
    # Candidate spans, in order: "Greater", "London", "Greater London".
    assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]]))
    labels = ["Thing", "City", "Person", "GreatCity"]
    # Three rows and four columns, matching the three candidate spans and the
    # four labels above (presumably one row per span, one column per label).
    scores = numpy.asarray(
        [[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"
    )
    spangroup = spancat._make_span_group_singlelabel(
        doc, indices, scores, labels, allow_overlap
    )
    assert len(spangroup) == nr_results
    if threshold > 0.4:
        if allow_overlap:
            assert spangroup[0].text == "London"
            assert spangroup[0].label_ == "City"
            assert spangroup[1].text == "Greater London"
            assert spangroup[1].label_ == "GreatCity"
        else:
            assert spangroup[0].text == "Greater London"
            assert spangroup[0].label_ == "GreatCity"
    else:
        if allow_overlap:
            assert spangroup[0].text == "Greater"
            assert spangroup[0].label_ == "City"
            assert spangroup[1].text == "London"
            assert spangroup[1].label_ == "City"
            assert spangroup[2].text == "Greater London"
            assert spangroup[2].label_ == "GreatCity"
        else:
            assert spangroup[0].text == "Greater London"
            # Also pin the label, consistent with the high-threshold
            # no-overlap branch above.
            assert spangroup[0].label_ == "GreatCity"
|
||||
|
||||
|
||||
def test_ngram_suggester(en_tokenizer):
|
||||
# test different n-gram lengths
|
||||
for size in [1, 2, 3]:
|
||||
|
|
Loading…
Reference in New Issue
Block a user