diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py
index 3bc2d98a4..4cdaf3d83 100644
--- a/spacy/pipeline/spancat.py
+++ b/spacy/pipeline/spancat.py
@@ -398,7 +398,7 @@ class SpanCategorizer(TrainablePipe):
         pass
 
     def _get_aligned_spans(self, eg: Example):
-        return eg.get_aligned_spans_y2x(eg.reference.spans.get(self.key, []))
+        return eg.get_aligned_spans_y2x(eg.reference.spans.get(self.key, []), allow_overlap=True)
 
     def _make_span_group(
         self, doc: Doc, indices: Ints2d, scores: Floats2d, labels: List[str]
diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py
index 974994372..3da5816ab 100644
--- a/spacy/tests/pipeline/test_spancat.py
+++ b/spacy/tests/pipeline/test_spancat.py
@@ -2,11 +2,14 @@ import pytest
 import numpy
 from numpy.testing import assert_array_equal, assert_almost_equal
 from thinc.api import get_current_ops
+
+from spacy import util
+from spacy.lang.en import English
 from spacy.language import Language
 from spacy.tokens.doc import SpanGroups
 from spacy.tokens import SpanGroup
 from spacy.training import Example
-from spacy.util import fix_random_seed, registry
+from spacy.util import fix_random_seed, registry, make_tempdir
 
 OPS = get_current_ops()
 
@@ -20,17 +23,21 @@ TRAIN_DATA = [
     ),
 ]
 
+TRAIN_DATA_OVERLAPPING = [
+    ("Who is Shaka Khan?", {"spans": {SPAN_KEY: [(7, 17, "PERSON")]}}),
+    (
+        "I like London and Berlin",
+        {"spans": {SPAN_KEY: [(7, 13, "LOC"), (18, 24, "LOC"), (7, 24, "DOUBLE_LOC")]}},
+    ),
+]
 
-def make_get_examples(nlp):
+
+def make_examples(nlp, data=TRAIN_DATA):
     train_examples = []
-    for t in TRAIN_DATA:
+    for t in data:
         eg = Example.from_dict(nlp.make_doc(t[0]), t[1])
         train_examples.append(eg)
-
-    def get_examples():
-        return train_examples
-
-    return get_examples
+    return train_examples
 
 
 def test_no_label():
@@ -57,9 +64,7 @@ def test_implicit_labels():
     nlp = Language()
     spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
     assert len(spancat.labels) == 0
-    train_examples = []
-    for t in TRAIN_DATA:
-        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    train_examples = make_examples(nlp)
     nlp.initialize(get_examples=lambda: train_examples)
     assert spancat.labels == ("PERSON", "LOC")
 
@@ -140,30 +145,6 @@ def test_make_spangroup(max_positive, nr_results):
     assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5)
 
 
-def test_simple_train():
-    fix_random_seed(0)
-    nlp = Language()
-    spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
-    get_examples = make_get_examples(nlp)
-    nlp.initialize(get_examples)
-    sgd = nlp.create_optimizer()
-    assert len(spancat.labels) != 0
-    for i in range(40):
-        losses = {}
-        nlp.update(list(get_examples()), losses=losses, drop=0.1, sgd=sgd)
-    doc = nlp("I like London and Berlin.")
-    assert doc.spans[spancat.key] == doc.spans[SPAN_KEY]
-    assert len(doc.spans[spancat.key]) == 2
-    assert len(doc.spans[spancat.key].attrs["scores"]) == 2
-    assert doc.spans[spancat.key][0].text == "London"
-    scores = nlp.evaluate(get_examples())
-    assert f"spans_{SPAN_KEY}_f" in scores
-    assert scores[f"spans_{SPAN_KEY}_f"] == 1.0
-    # also test that the spancat works for just a single entity in a sentence
-    doc = nlp("London")
-    assert len(doc.spans[spancat.key]) == 1
-
-
 def test_ngram_suggester(en_tokenizer):
     # test different n-gram lengths
     for size in [1, 2, 3]:
@@ -282,3 +263,92 @@ def test_ngram_sizes(en_tokenizer):
     range_suggester = suggester_factory(min_size=2, max_size=4)
     ngrams_3 = range_suggester(docs)
     assert_array_equal(OPS.to_numpy(ngrams_3.lengths), [0, 1, 3, 6, 9])
+
+
+def test_overfitting_IO():
+    # Simple test to try and quickly overfit the spancat component - ensuring the ML models work correctly
+    fix_random_seed(0)
+    nlp = English()
+    spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
+    train_examples = make_examples(nlp)
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+    assert spancat.model.get_dim("nO") == 2
+    assert set(spancat.labels) == {"LOC", "PERSON"}
+
+    for i in range(50):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+    assert losses["spancat"] < 0.01
+
+    # test the trained model
+    test_text = "I like London and Berlin"
+    doc = nlp(test_text)
+    assert doc.spans[spancat.key] == doc.spans[SPAN_KEY]
+    spans = doc.spans[SPAN_KEY]
+    assert len(spans) == 2
+    assert len(spans.attrs["scores"]) == 2
+    assert min(spans.attrs["scores"]) > 0.9
+    assert set([span.text for span in spans]) == {"London", "Berlin"}
+    assert set([span.label_ for span in spans]) == {"LOC"}
+
+    # Also test the results are still the same after IO
+    with make_tempdir() as tmp_dir:
+        nlp.to_disk(tmp_dir)
+        nlp2 = util.load_model_from_path(tmp_dir)
+        doc2 = nlp2(test_text)
+        spans2 = doc2.spans[SPAN_KEY]
+        assert len(spans2) == 2
+        assert len(spans2.attrs["scores"]) == 2
+        assert min(spans2.attrs["scores"]) > 0.9
+        assert set([span.text for span in spans2]) == {"London", "Berlin"}
+        assert set([span.label_ for span in spans2]) == {"LOC"}
+
+    # Test scoring
+    scores = nlp.evaluate(train_examples)
+    assert f"spans_{SPAN_KEY}_f" in scores
+    assert scores[f"spans_{SPAN_KEY}_p"] == 1.0
+    assert scores[f"spans_{SPAN_KEY}_r"] == 1.0
+    assert scores[f"spans_{SPAN_KEY}_f"] == 1.0
+
+    # also test that the spancat works for just a single entity in a sentence
+    doc = nlp("London")
+    assert len(doc.spans[spancat.key]) == 1
+
+
+def test_overfitting_IO_overlapping():
+    # Test for overfitting on overlapping entities
+    fix_random_seed(0)
+    nlp = English()
+    spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
+
+    train_examples = make_examples(nlp, data=TRAIN_DATA_OVERLAPPING)
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+    assert spancat.model.get_dim("nO") == 3
+    assert set(spancat.labels) == {"PERSON", "LOC", "DOUBLE_LOC"}
+
+    for i in range(50):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+    assert losses["spancat"] < 0.01
+
+    # test the trained model
+    test_text = "I like London and Berlin"
+    doc = nlp(test_text)
+    spans = doc.spans[SPAN_KEY]
+    assert len(spans) == 3
+    assert len(spans.attrs["scores"]) == 3
+    assert min(spans.attrs["scores"]) > 0.9
+    assert set([span.text for span in spans]) == {"London", "Berlin", "London and Berlin"}
+    assert set([span.label_ for span in spans]) == {"LOC", "DOUBLE_LOC"}
+
+    # Also test the results are still the same after IO
+    with make_tempdir() as tmp_dir:
+        nlp.to_disk(tmp_dir)
+        nlp2 = util.load_model_from_path(tmp_dir)
+        doc2 = nlp2(test_text)
+        spans2 = doc2.spans[SPAN_KEY]
+        assert len(spans2) == 3
+        assert len(spans2.attrs["scores"]) == 3
+        assert min(spans2.attrs["scores"]) > 0.9
+        assert set([span.text for span in spans2]) == {"London", "Berlin", "London and Berlin"}
+        assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"}
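
For reviewers who want to see the effect of the `allow_overlap=True` change outside the test suite, here is a minimal sketch (not part of the diff). It reuses the overlapping annotations from `TRAIN_DATA_OVERLAPPING` and assumes a spans key of `"sc"` in place of the tests' `SPAN_KEY` constant; the point is that `Example.get_aligned_spans_y2x`, which the patched `_get_aligned_spans` now calls with `allow_overlap=True`, keeps overlapping gold spans such as the `DOUBLE_LOC` span during alignment instead of filtering them.

```python
# Minimal sketch, not part of the patch. SPAN_KEY = "sc" is an assumption
# standing in for the tests' SPAN_KEY constant.
from spacy.lang.en import English
from spacy.training import Example

SPAN_KEY = "sc"
nlp = English()
text = "I like London and Berlin"
annotations = {
    "spans": {
        SPAN_KEY: [(7, 13, "LOC"), (18, 24, "LOC"), (7, 24, "DOUBLE_LOC")]
    }
}
eg = Example.from_dict(nlp.make_doc(text), annotations)

# The same call the patched _get_aligned_spans makes: with allow_overlap=True,
# all three gold spans survive alignment, including the DOUBLE_LOC span that
# covers both LOC spans.
aligned = eg.get_aligned_spans_y2x(
    eg.reference.spans.get(SPAN_KEY, []), allow_overlap=True
)
assert {(s.text, s.label_) for s in aligned} == {
    ("London", "LOC"),
    ("Berlin", "LOC"),
    ("London and Berlin", "DOUBLE_LOC"),
}
```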