Fix overfitting test (#6011)

* remove unused MORPH_RULES

* fix textcat architecture in overfitting test
This commit is contained in:
Sofie Van Landeghem 2020-09-02 13:07:41 +02:00 committed by GitHub
parent b97d98783a
commit eb56377799
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 9 deletions

View File

@@ -28,8 +28,6 @@ def test_tagger_begin_training_tag_map():
TAGS = ("N", "V", "J")
MORPH_RULES = {"V": {"like": {"lemma": "luck"}}}
TRAIN_DATA = [
("I like green eggs", {"tags": ["N", "V", "J", "N"]}),
("Eat blue ham", {"tags": ["V", "J", "N"]}),

View File

@@ -84,9 +84,8 @@ def test_overfitting_IO():
# Simple test to try and quickly overfit the textcat component - ensuring the ML models work correctly
fix_random_seed(0)
nlp = English()
textcat = nlp.add_pipe("textcat")
# Set exclusive labels
textcat.model.attrs["multi_label"] = False
textcat = nlp.add_pipe("textcat", config={"model": {"exclusive_classes": True}})
train_examples = []
for text, annotations in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
@@ -103,9 +102,8 @@ def test_overfitting_IO():
test_text = "I am happy."
doc = nlp(test_text)
cats = doc.cats
# note that by default, exclusive_classes = false so we need a bigger error margin
assert cats["POSITIVE"] > 0.8
assert cats["POSITIVE"] + cats["NEGATIVE"] == pytest.approx(1.0, 0.1)
assert cats["POSITIVE"] > 0.9
assert cats["POSITIVE"] + cats["NEGATIVE"] == pytest.approx(1.0, 0.001)
# Also test the results are still the same after IO
with make_tempdir() as tmp_dir:
@@ -113,8 +111,8 @@ def test_overfitting_IO():
nlp2 = util.load_model_from_path(tmp_dir)
doc2 = nlp2(test_text)
cats2 = doc2.cats
assert cats2["POSITIVE"] > 0.8
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1)
assert cats2["POSITIVE"] > 0.9
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.001)
# Test scoring
scores = nlp.evaluate(train_examples, scorer_cfg={"positive_label": "POSITIVE"})