mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
parent
593a22cf2d
commit
0a6b68848f
|
@ -411,7 +411,10 @@ class SpanCategorizer(TrainablePipe):
|
|||
|
||||
keeps = scores >= threshold
|
||||
ranked = (scores * -1).argsort()
|
||||
keeps[ranked[:, max_positive:]] = False
|
||||
if max_positive is not None:
|
||||
filter = ranked[:, max_positive:]
|
||||
for i, row in enumerate(filter):
|
||||
keeps[i, row] = False
|
||||
spans.attrs["scores"] = scores[keeps].flatten()
|
||||
|
||||
indices = self.model.ops.to_numpy(indices)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
from numpy.testing import assert_equal, assert_array_equal
|
||||
import numpy
|
||||
from numpy.testing import assert_equal, assert_array_equal, assert_almost_equal
|
||||
from thinc.api import get_current_ops
|
||||
from spacy.language import Language
|
||||
from spacy.training import Example
|
||||
|
@ -71,6 +72,55 @@ def test_explicit_labels():
|
|||
assert spancat.labels == ("PERSON", "LOC")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"max_positive,nr_results", [(None, 4), (1, 2), (2, 3), (3, 4), (4, 4)]
|
||||
)
|
||||
def test_make_spangroup(max_positive, nr_results):
|
||||
fix_random_seed(0)
|
||||
nlp = Language()
|
||||
spancat = nlp.add_pipe(
|
||||
"spancat",
|
||||
config={"spans_key": SPAN_KEY, "threshold": 0.5, "max_positive": max_positive},
|
||||
)
|
||||
doc = nlp.make_doc("Greater London")
|
||||
ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2])
|
||||
indices = ngram_suggester([doc])[0].dataXd
|
||||
assert_array_equal(indices, numpy.asarray([[0, 1], [1, 2], [0, 2]]))
|
||||
labels = ["Thing", "City", "Person", "GreatCity"]
|
||||
scores = numpy.asarray(
|
||||
[[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"
|
||||
)
|
||||
spangroup = spancat._make_span_group(doc, indices, scores, labels)
|
||||
assert len(spangroup) == nr_results
|
||||
|
||||
# first span is always the second token "London"
|
||||
assert spangroup[0].text == "London"
|
||||
assert spangroup[0].label_ == "City"
|
||||
assert_almost_equal(0.6, spangroup.attrs["scores"][0], 5)
|
||||
|
||||
# second span depends on the number of positives that were allowed
|
||||
assert spangroup[1].text == "Greater London"
|
||||
if max_positive == 1:
|
||||
assert spangroup[1].label_ == "GreatCity"
|
||||
assert_almost_equal(0.9, spangroup.attrs["scores"][1], 5)
|
||||
else:
|
||||
assert spangroup[1].label_ == "Thing"
|
||||
assert_almost_equal(0.8, spangroup.attrs["scores"][1], 5)
|
||||
|
||||
if nr_results > 2:
|
||||
assert spangroup[2].text == "Greater London"
|
||||
if max_positive == 2:
|
||||
assert spangroup[2].label_ == "GreatCity"
|
||||
assert_almost_equal(0.9, spangroup.attrs["scores"][2], 5)
|
||||
else:
|
||||
assert spangroup[2].label_ == "City"
|
||||
assert_almost_equal(0.7, spangroup.attrs["scores"][2], 5)
|
||||
|
||||
assert spangroup[-1].text == "Greater London"
|
||||
assert spangroup[-1].label_ == "GreatCity"
|
||||
assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5)
|
||||
|
||||
|
||||
def test_simple_train():
|
||||
fix_random_seed(0)
|
||||
nlp = Language()
|
||||
|
@ -90,6 +140,9 @@ def test_simple_train():
|
|||
scores = nlp.evaluate(get_examples())
|
||||
assert f"spans_{SPAN_KEY}_f" in scores
|
||||
assert scores[f"spans_{SPAN_KEY}_f"] == 1.0
|
||||
# also test that the spancat works for just a single entity in a sentence
|
||||
doc = nlp("London")
|
||||
assert len(doc.spans[spancat.key]) == 1
|
||||
|
||||
|
||||
def test_ngram_suggester(en_tokenizer):
|
||||
|
|
Loading…
Reference in New Issue
Block a user