mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-21 17:41:59 +03:00
failing overfit test
This commit is contained in:
parent
f599bd5a4d
commit
90af16af76
|
@ -3,12 +3,40 @@ from thinc.api import Config
|
|||
from thinc.types import Ragged
|
||||
|
||||
from spacy.language import Language
|
||||
from spacy.lang.en import English
|
||||
from spacy.pipeline.span_finder import DEFAULT_PREDICTED_KEY, span_finder_default_config
|
||||
from spacy.tokens import Doc
|
||||
from spacy.training import Example
|
||||
from spacy import util
|
||||
from spacy.util import registry
|
||||
from spacy.util import fix_random_seed, make_tempdir
|
||||
|
||||
|
||||
TRAINING_KEY = "pytest"
|
||||
TRAIN_DATA = [
|
||||
("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}),
|
||||
(
|
||||
"I like London and Berlin.",
|
||||
{"spans": {TRAINING_KEY: [(7, 13, "LOC"), (18, 24)]}},
|
||||
),
|
||||
]
|
||||
|
||||
TRAIN_DATA_OVERLAPPING = [
|
||||
("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}),
|
||||
(
|
||||
"I like London and Berlin",
|
||||
{"spans": {TRAINING_KEY: [(7, 13), (18, 24), (7, 24)]}},
|
||||
),
|
||||
("", {"spans": {TRAINING_KEY: []}}),
|
||||
]
|
||||
|
||||
|
||||
def make_examples(nlp, data=TRAIN_DATA):
|
||||
train_examples = []
|
||||
for t in data:
|
||||
eg = Example.from_dict(nlp.make_doc(t[0]), t[1])
|
||||
train_examples.append(eg)
|
||||
return train_examples
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -191,3 +219,50 @@ def test_span_finder_suggester():
|
|||
assert span_length == len(candidates.dataXd)
|
||||
assert type(candidates) == Ragged
|
||||
assert len(candidates.dataXd[0]) == 2
|
||||
|
||||
|
||||
# XXX Fails because i think the suggester is not correctly implemented?
|
||||
def test_overfitting_IO():
|
||||
# Simple test to try and quickly overfit the spancat component - ensuring the ML models work correctly
|
||||
fix_random_seed(0)
|
||||
nlp = English()
|
||||
span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
|
||||
train_examples = make_examples(nlp)
|
||||
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
||||
assert span_finder.model.get_dim("nO") == 2
|
||||
|
||||
for i in range(50):
|
||||
losses = {}
|
||||
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
||||
assert losses["span_finder"] < 0.001
|
||||
|
||||
# test the trained model
|
||||
test_text = "I like London and Berlin"
|
||||
doc = nlp(test_text)
|
||||
spans = doc.spans[span_finder.predicted_key]
|
||||
assert len(spans) == 2
|
||||
assert len(spans.attrs["scores"]) == 2
|
||||
assert min(spans.attrs["scores"]) > 0.9
|
||||
assert set([span.text for span in spans]) == {"London", "Berlin"}
|
||||
|
||||
# Also test the results are still the same after IO
|
||||
with make_tempdir() as tmp_dir:
|
||||
nlp.to_disk(tmp_dir)
|
||||
nlp2 = util.load_model_from_path(tmp_dir)
|
||||
doc2 = nlp2(test_text)
|
||||
spans2 = doc2.spans[TRAINING_KEY]
|
||||
assert len(spans2) == 2
|
||||
assert len(spans2.attrs["scores"]) == 2
|
||||
assert min(spans2.attrs["scores"]) > 0.9
|
||||
assert set([span.text for span in spans2]) == {"London", "Berlin"}
|
||||
|
||||
# Test scoring
|
||||
scores = nlp.evaluate(train_examples)
|
||||
assert f"spans_{TRAINING_KEY}_f" in scores
|
||||
assert scores[f"spans_{TRAINING_KEY}_p"] == 1.0
|
||||
assert scores[f"spans_{TRAINING_KEY}_r"] == 1.0
|
||||
assert scores[f"spans_{TRAINING_KEY}_f"] == 1.0
|
||||
|
||||
# also test that the spancat works for just a single entity in a sentence
|
||||
doc = nlp("London")
|
||||
assert len(doc.spans[span_finder.predicted_key]) == 1
|
||||
|
|
Loading…
Reference in New Issue
Block a user