Fix spancat training on nested entities (#9007)

* overfitting test on non-overlapping entities

* add failing overfitting test for overlapping entities

* failing test for list comprehension

* remove test that was put in separate PR

* bugfix

* cleanup
Sofie Van Landeghem 2021-08-20 12:37:50 +02:00 committed by GitHub
parent 9cc3dc2b67
commit 4d52d7051c
2 changed files with 106 additions and 36 deletions

spacy/pipeline/spancat.py

@@ -398,7 +398,7 @@ class SpanCategorizer(TrainablePipe):
         pass
 
     def _get_aligned_spans(self, eg: Example):
-        return eg.get_aligned_spans_y2x(eg.reference.spans.get(self.key, []))
+        return eg.get_aligned_spans_y2x(eg.reference.spans.get(self.key, []), allow_overlap=True)
 
     def _make_span_group(
         self, doc: Doc, indices: Ints2d, scores: Floats2d, labels: List[str]
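
This one-line change is the actual bugfix: during training, the gold spans are aligned from the reference tokenization onto the predicted tokenization, and without allow_overlap=True that alignment keeps only a non-overlapping subset of the spans, so nested annotations never reach the spancat loss. Below is a minimal sketch of the difference, assuming the default alignment applies overlap filtering (which the added allow_overlap=True bypasses); the "sc" spans key, the example sentence, and the token indices are illustrative, not taken from this diff.

from spacy.lang.en import English
from spacy.tokens import Span
from spacy.training import Example

nlp = English()
reference = nlp.make_doc("I like London and Berlin")
# Nested gold spans (token indices): two LOC spans plus one span covering both.
reference.spans["sc"] = [
    Span(reference, 2, 3, "LOC"),
    Span(reference, 4, 5, "LOC"),
    Span(reference, 2, 5, "DOUBLE_LOC"),
]
example = Example(nlp.make_doc("I like London and Berlin"), reference)

# Default alignment: only a non-overlapping subset of the gold spans survives.
print([s.text for s in example.get_aligned_spans_y2x(reference.spans["sc"])])
# With allow_overlap=True (as in the fix above), all three gold spans survive.
print([s.text for s in example.get_aligned_spans_y2x(reference.spans["sc"], allow_overlap=True)])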

spacy/tests/pipeline/test_spancat.py

@@ -2,11 +2,14 @@ import pytest
 import numpy
 from numpy.testing import assert_array_equal, assert_almost_equal
 from thinc.api import get_current_ops
+from spacy import util
+from spacy.lang.en import English
 from spacy.language import Language
 from spacy.tokens.doc import SpanGroups
 from spacy.tokens import SpanGroup
 from spacy.training import Example
-from spacy.util import fix_random_seed, registry
+from spacy.util import fix_random_seed, registry, make_tempdir
 
 
 OPS = get_current_ops()
@@ -20,18 +23,22 @@ TRAIN_DATA = [
     ),
 ]
 
 
+TRAIN_DATA_OVERLAPPING = [
+    ("Who is Shaka Khan?", {"spans": {SPAN_KEY: [(7, 17, "PERSON")]}}),
+    (
+        "I like London and Berlin",
+        {"spans": {SPAN_KEY: [(7, 13, "LOC"), (18, 24, "LOC"), (7, 24, "DOUBLE_LOC")]}},
+    ),
+]
+
+
-def make_get_examples(nlp):
+def make_examples(nlp, data=TRAIN_DATA):
     train_examples = []
-    for t in TRAIN_DATA:
+    for t in data:
         eg = Example.from_dict(nlp.make_doc(t[0]), t[1])
         train_examples.append(eg)
-
-    def get_examples():
-        return train_examples
-
-    return get_examples
+    return train_examples
 
 
 def test_no_label():
     nlp = Language()
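
The new TRAIN_DATA_OVERLAPPING entries use (start_char, end_char, label) offsets, and the third span deliberately contains the first two, which is exactly the nested case the fix targets. A quick, purely illustrative check of what those offsets cover:

text = "I like London and Berlin"
for start, end, label in [(7, 13, "LOC"), (18, 24, "LOC"), (7, 24, "DOUBLE_LOC")]:
    # Prints: LOC 'London', LOC 'Berlin', DOUBLE_LOC 'London and Berlin'
    print(label, repr(text[start:end]))
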
@@ -57,9 +64,7 @@ def test_implicit_labels():
     nlp = Language()
     spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
     assert len(spancat.labels) == 0
-    train_examples = []
-    for t in TRAIN_DATA:
-        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    train_examples = make_examples(nlp)
     nlp.initialize(get_examples=lambda: train_examples)
     assert spancat.labels == ("PERSON", "LOC")
@@ -140,30 +145,6 @@ def test_make_spangroup(max_positive, nr_results):
     assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5)
 
 
-def test_simple_train():
-    fix_random_seed(0)
-    nlp = Language()
-    spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
-    get_examples = make_get_examples(nlp)
-    nlp.initialize(get_examples)
-    sgd = nlp.create_optimizer()
-    assert len(spancat.labels) != 0
-    for i in range(40):
-        losses = {}
-        nlp.update(list(get_examples()), losses=losses, drop=0.1, sgd=sgd)
-    doc = nlp("I like London and Berlin.")
-    assert doc.spans[spancat.key] == doc.spans[SPAN_KEY]
-    assert len(doc.spans[spancat.key]) == 2
-    assert len(doc.spans[spancat.key].attrs["scores"]) == 2
-    assert doc.spans[spancat.key][0].text == "London"
-    scores = nlp.evaluate(get_examples())
-    assert f"spans_{SPAN_KEY}_f" in scores
-    assert scores[f"spans_{SPAN_KEY}_f"] == 1.0
-    # also test that the spancat works for just a single entity in a sentence
-    doc = nlp("London")
-    assert len(doc.spans[spancat.key]) == 1
-
-
 def test_ngram_suggester(en_tokenizer):
     # test different n-gram lengths
     for size in [1, 2, 3]:
@@ -282,3 +263,92 @@ def test_ngram_sizes(en_tokenizer):
     range_suggester = suggester_factory(min_size=2, max_size=4)
     ngrams_3 = range_suggester(docs)
     assert_array_equal(OPS.to_numpy(ngrams_3.lengths), [0, 1, 3, 6, 9])
+
+
+def test_overfitting_IO():
+    # Simple test to try and quickly overfit the spancat component - ensuring the ML models work correctly
+    fix_random_seed(0)
+    nlp = English()
+    spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
+    train_examples = make_examples(nlp)
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+    assert spancat.model.get_dim("nO") == 2
+    assert set(spancat.labels) == {"LOC", "PERSON"}
+
+    for i in range(50):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+    assert losses["spancat"] < 0.01
+
+    # test the trained model
+    test_text = "I like London and Berlin"
+    doc = nlp(test_text)
+    assert doc.spans[spancat.key] == doc.spans[SPAN_KEY]
+    spans = doc.spans[SPAN_KEY]
+    assert len(spans) == 2
+    assert len(spans.attrs["scores"]) == 2
+    assert min(spans.attrs["scores"]) > 0.9
+    assert set([span.text for span in spans]) == {"London", "Berlin"}
+    assert set([span.label_ for span in spans]) == {"LOC"}
+
+    # Also test the results are still the same after IO
+    with make_tempdir() as tmp_dir:
+        nlp.to_disk(tmp_dir)
+        nlp2 = util.load_model_from_path(tmp_dir)
+        doc2 = nlp2(test_text)
+        spans2 = doc2.spans[SPAN_KEY]
+        assert len(spans2) == 2
+        assert len(spans2.attrs["scores"]) == 2
+        assert min(spans2.attrs["scores"]) > 0.9
+        assert set([span.text for span in spans2]) == {"London", "Berlin"}
+        assert set([span.label_ for span in spans2]) == {"LOC"}
+
+    # Test scoring
+    scores = nlp.evaluate(train_examples)
+    assert f"spans_{SPAN_KEY}_f" in scores
+    assert scores[f"spans_{SPAN_KEY}_p"] == 1.0
+    assert scores[f"spans_{SPAN_KEY}_r"] == 1.0
+    assert scores[f"spans_{SPAN_KEY}_f"] == 1.0
+
+    # also test that the spancat works for just a single entity in a sentence
+    doc = nlp("London")
+    assert len(doc.spans[spancat.key]) == 1
+
+
+def test_overfitting_IO_overlapping():
+    # Test for overfitting on overlapping entities
+    fix_random_seed(0)
+    nlp = English()
+    spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
+    train_examples = make_examples(nlp, data=TRAIN_DATA_OVERLAPPING)
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+    assert spancat.model.get_dim("nO") == 3
+    assert set(spancat.labels) == {"PERSON", "LOC", "DOUBLE_LOC"}
+
+    for i in range(50):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+    assert losses["spancat"] < 0.01
+
+    # test the trained model
+    test_text = "I like London and Berlin"
+    doc = nlp(test_text)
+    spans = doc.spans[SPAN_KEY]
+    assert len(spans) == 3
+    assert len(spans.attrs["scores"]) == 3
+    assert min(spans.attrs["scores"]) > 0.9
+    assert set([span.text for span in spans]) == {"London", "Berlin", "London and Berlin"}
+    assert set([span.label_ for span in spans]) == {"LOC", "DOUBLE_LOC"}
+
+    # Also test the results are still the same after IO
+    with make_tempdir() as tmp_dir:
+        nlp.to_disk(tmp_dir)
+        nlp2 = util.load_model_from_path(tmp_dir)
+        doc2 = nlp2(test_text)
+        spans2 = doc2.spans[SPAN_KEY]
+        assert len(spans2) == 3
+        assert len(spans2.attrs["scores"]) == 3
+        assert min(spans2.attrs["scores"]) > 0.9
+        assert set([span.text for span in spans2]) == {"London", "Berlin", "London and Berlin"}
+        assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"}
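
One thing the new overlapping test relies on implicitly: spancat only scores spans that its suggester proposes, so the three-token candidate "London and Berlin" has to be generated in the first place. The default configuration uses spaCy's registered n-gram suggester, whose sizes are assumed here to be [1, 2, 3]; the sketch below spells that out explicitly when adding the pipe (the "sc" spans key is illustrative, not taken from this diff).

from spacy.lang.en import English

nlp = English()
# Explicit n-gram suggester config so that 1-, 2- and 3-token candidates are
# proposed, which includes the nested "London and Berlin" span.
spancat = nlp.add_pipe(
    "spancat",
    config={
        "spans_key": "sc",
        "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
    },
)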