failing overfit test

This commit is contained in:
kadarakos 2023-05-31 17:30:56 +00:00
parent f599bd5a4d
commit 90af16af76

View File

@ -3,12 +3,40 @@ from thinc.api import Config
from thinc.types import Ragged
from spacy.language import Language
from spacy.lang.en import English
from spacy.pipeline.span_finder import DEFAULT_PREDICTED_KEY, span_finder_default_config
from spacy.tokens import Doc
from spacy.training import Example
from spacy import util
from spacy.util import registry
from spacy.util import fix_random_seed, make_tempdir
TRAINING_KEY = "pytest"
TRAIN_DATA = [
("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}),
(
"I like London and Berlin.",
{"spans": {TRAINING_KEY: [(7, 13, "LOC"), (18, 24)]}},
),
]
TRAIN_DATA_OVERLAPPING = [
("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}),
(
"I like London and Berlin",
{"spans": {TRAINING_KEY: [(7, 13), (18, 24), (7, 24)]}},
),
("", {"spans": {TRAINING_KEY: []}}),
]
def make_examples(nlp, data=TRAIN_DATA):
train_examples = []
for t in data:
eg = Example.from_dict(nlp.make_doc(t[0]), t[1])
train_examples.append(eg)
return train_examples
@pytest.mark.parametrize(
@ -191,3 +219,50 @@ def test_span_finder_suggester():
assert span_length == len(candidates.dataXd)
assert type(candidates) == Ragged
assert len(candidates.dataXd[0]) == 2
# XXX Fails because i think the suggester is not correctly implemented?
def test_overfitting_IO():
# Simple test to try and quickly overfit the spancat component - ensuring the ML models work correctly
fix_random_seed(0)
nlp = English()
span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
train_examples = make_examples(nlp)
optimizer = nlp.initialize(get_examples=lambda: train_examples)
assert span_finder.model.get_dim("nO") == 2
for i in range(50):
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
assert losses["span_finder"] < 0.001
# test the trained model
test_text = "I like London and Berlin"
doc = nlp(test_text)
spans = doc.spans[span_finder.predicted_key]
assert len(spans) == 2
assert len(spans.attrs["scores"]) == 2
assert min(spans.attrs["scores"]) > 0.9
assert set([span.text for span in spans]) == {"London", "Berlin"}
# Also test the results are still the same after IO
with make_tempdir() as tmp_dir:
nlp.to_disk(tmp_dir)
nlp2 = util.load_model_from_path(tmp_dir)
doc2 = nlp2(test_text)
spans2 = doc2.spans[TRAINING_KEY]
assert len(spans2) == 2
assert len(spans2.attrs["scores"]) == 2
assert min(spans2.attrs["scores"]) > 0.9
assert set([span.text for span in spans2]) == {"London", "Berlin"}
# Test scoring
scores = nlp.evaluate(train_examples)
assert f"spans_{TRAINING_KEY}_f" in scores
assert scores[f"spans_{TRAINING_KEY}_p"] == 1.0
assert scores[f"spans_{TRAINING_KEY}_r"] == 1.0
assert scores[f"spans_{TRAINING_KEY}_f"] == 1.0
# also test that the spancat works for just a single entity in a sentence
doc = nlp("London")
assert len(doc.spans[span_finder.predicted_key]) == 1