Mirror of https://github.com/explosion/spaCy.git (synced 2025-07-10 16:22:29 +03:00)

Commit: 90af16af76
Parent: f599bd5a4d

failing overfit test
@@ -3,12 +3,40 @@ from thinc.api import Config
 from thinc.types import Ragged
 
 from spacy.language import Language
+from spacy.lang.en import English
 from spacy.pipeline.span_finder import DEFAULT_PREDICTED_KEY, span_finder_default_config
 from spacy.tokens import Doc
 from spacy.training import Example
+from spacy import util
 from spacy.util import registry
+from spacy.util import fix_random_seed, make_tempdir
 
+
 TRAINING_KEY = "pytest"
+TRAIN_DATA = [
+    ("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}),
+    (
+        "I like London and Berlin.",
+        {"spans": {TRAINING_KEY: [(7, 13, "LOC"), (18, 24)]}},
+    ),
+]
+
+TRAIN_DATA_OVERLAPPING = [
+    ("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}),
+    (
+        "I like London and Berlin",
+        {"spans": {TRAINING_KEY: [(7, 13), (18, 24), (7, 24)]}},
+    ),
+    ("", {"spans": {TRAINING_KEY: []}}),
+]
+
+
+def make_examples(nlp, data=TRAIN_DATA):
+    train_examples = []
+    for t in data:
+        eg = Example.from_dict(nlp.make_doc(t[0]), t[1])
+        train_examples.append(eg)
+    return train_examples
+
+
 @pytest.mark.parametrize(
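As a sanity check on the character offsets in TRAIN_DATA above (an illustrative sketch by the editor, not part of the commit), the (start, end) pairs index into the raw example text:

    text = "Who is Shaka Khan?"
    assert text[7:17] == "Shaka Khan"

    text = "I like London and Berlin."
    assert text[7:13] == "London"
    assert text[18:24] == "Berlin"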
@@ -191,3 +219,50 @@ def test_span_finder_suggester():
     assert span_length == len(candidates.dataXd)
     assert type(candidates) == Ragged
     assert len(candidates.dataXd[0]) == 2
+
+
+# XXX Fails because I think the suggester is not correctly implemented?
+def test_overfitting_IO():
+    # Simple test to try and quickly overfit the span_finder component - ensuring the ML models work correctly
+    fix_random_seed(0)
+    nlp = English()
+    span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
+    train_examples = make_examples(nlp)
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+    assert span_finder.model.get_dim("nO") == 2
+
+    for i in range(50):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+    assert losses["span_finder"] < 0.001
+
+    # test the trained model
+    test_text = "I like London and Berlin"
+    doc = nlp(test_text)
+    spans = doc.spans[span_finder.predicted_key]
+    assert len(spans) == 2
+    assert len(spans.attrs["scores"]) == 2
+    assert min(spans.attrs["scores"]) > 0.9
+    assert set([span.text for span in spans]) == {"London", "Berlin"}
+
+    # Also test the results are still the same after IO
+    with make_tempdir() as tmp_dir:
+        nlp.to_disk(tmp_dir)
+        nlp2 = util.load_model_from_path(tmp_dir)
+        doc2 = nlp2(test_text)
+        spans2 = doc2.spans[TRAINING_KEY]
+        assert len(spans2) == 2
+        assert len(spans2.attrs["scores"]) == 2
+        assert min(spans2.attrs["scores"]) > 0.9
+        assert set([span.text for span in spans2]) == {"London", "Berlin"}
+
+    # Test scoring
+    scores = nlp.evaluate(train_examples)
+    assert f"spans_{TRAINING_KEY}_f" in scores
+    assert scores[f"spans_{TRAINING_KEY}_p"] == 1.0
+    assert scores[f"spans_{TRAINING_KEY}_r"] == 1.0
+    assert scores[f"spans_{TRAINING_KEY}_f"] == 1.0
+
+    # also test that the span_finder works for just a single entity in a sentence
+    doc = nlp("London")
+    assert len(doc.spans[span_finder.predicted_key]) == 1
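For context on the suggester assertions kept above and the XXX note in the new test: below is a minimal sketch of the candidate structure those assertions check, assuming the suggester follows the spancat convention of returning a thinc Ragged whose dataXd rows are (start, end) token offsets and whose lengths give the candidate count per doc. The offset values are made up for illustration and are not part of the commit.

    import numpy
    from thinc.types import Ragged

    # Two docs: three candidate spans for the first, one for the second (illustrative values).
    data = numpy.asarray([[0, 1], [1, 2], [0, 2], [5, 7]], dtype="i")
    lengths = numpy.asarray([3, 1], dtype="i")
    candidates = Ragged(data, lengths)

    assert type(candidates) == Ragged
    assert len(candidates.dataXd) == 4  # total number of candidate spans across docs
    assert len(candidates.dataXd[0]) == 2  # each row is a (start, end) pair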