# Imports added for the overfitting/IO test below.
from spacy import util
from spacy.lang.en import English
from spacy.training import Example
from spacy.util import fix_random_seed, make_tempdir

# Key under which gold spans are stored in doc.spans for these tests.
TRAINING_KEY = "pytest"

# Small training corpus: (text, annotations) pairs. Spans are character
# offsets, given as (start_char, end_char) or (start_char, end_char, label).
TRAIN_DATA = [
    ("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}),
    (
        "I like London and Berlin.",
        {"spans": {TRAINING_KEY: [(7, 13, "LOC"), (18, 24)]}},
    ),
]

# Variant with overlapping spans plus an empty document, for suggester tests.
TRAIN_DATA_OVERLAPPING = [
    ("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}),
    (
        "I like London and Berlin",
        {"spans": {TRAINING_KEY: [(7, 13), (18, 24), (7, 24)]}},
    ),
    ("", {"spans": {TRAINING_KEY: []}}),
]


def make_examples(nlp, data=TRAIN_DATA):
    """Build Example objects from (text, annotation-dict) pairs.

    Each pair is tokenized with ``nlp.make_doc`` and combined with its
    annotations via ``Example.from_dict``.
    """
    return [Example.from_dict(nlp.make_doc(text), annots) for text, annots in data]


# NOTE(review): the original author flagged that test_overfitting_IO below
# fails, suspecting the suggester implementation — confirm; see also the
# wrong-key fix in that test's IO section.
def test_overfitting_IO():
    # Quickly overfit the span_finder component to verify the ML model
    # trains correctly, then check predictions survive a to_disk/load
    # round-trip and that scoring reports perfect P/R/F on the train set.
    fix_random_seed(0)
    nlp = English()
    span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
    train_examples = make_examples(nlp)
    optimizer = nlp.initialize(get_examples=lambda: train_examples)
    # Two outputs per token: span-start and span-end probabilities.
    assert span_finder.model.get_dim("nO") == 2

    losses = {}
    for _ in range(50):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)
    # Assert after the final update: the model should have overfit by now.
    assert losses["span_finder"] < 0.001

    # Test the trained model. Predictions are written under the component's
    # predicted_key, not the training key.
    test_text = "I like London and Berlin"
    doc = nlp(test_text)
    spans = doc.spans[span_finder.predicted_key]
    assert len(spans) == 2
    assert len(spans.attrs["scores"]) == 2
    assert min(spans.attrs["scores"]) > 0.9
    assert {span.text for span in spans} == {"London", "Berlin"}

    # Also test that the results are still the same after IO.
    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = util.load_model_from_path(tmp_dir)
        doc2 = nlp2(test_text)
        # BUG FIX: the original read doc2.spans[TRAINING_KEY], but the
        # reloaded pipeline writes its predictions under predicted_key just
        # like the in-memory one does above; TRAINING_KEY holds gold spans.
        spans2 = doc2.spans[span_finder.predicted_key]
        assert len(spans2) == 2
        assert len(spans2.attrs["scores"]) == 2
        assert min(spans2.attrs["scores"]) > 0.9
        assert {span.text for span in spans2} == {"London", "Berlin"}

    # Test scoring on the training examples.
    scores = nlp.evaluate(train_examples)
    assert f"spans_{TRAINING_KEY}_f" in scores
    assert scores[f"spans_{TRAINING_KEY}_p"] == 1.0
    assert scores[f"spans_{TRAINING_KEY}_r"] == 1.0
    assert scores[f"spans_{TRAINING_KEY}_f"] == 1.0

    # Also test that the component works for a single entity in a document.
    doc = nlp("London")
    assert len(doc.spans[span_finder.predicted_key]) == 1