Fixes in test suite (#6457)

* fix slow test for textcat readers

* cleanup test_issue5551

* add explicit score weight

* cleanup
This commit is contained in:
Sofie Van Landeghem 2020-12-02 12:57:08 +01:00 committed by GitHub
parent 31ec9a906e
commit d6c616a125
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 19 deletions

View File

@@ -1,35 +1,38 @@
from thinc.api import fix_random_seed
import pytest
from thinc.api import Config, fix_random_seed
from spacy.lang.en import English
from spacy.pipeline.textcat import default_model_config, bow_model_config
from spacy.pipeline.textcat import cnn_model_config
from spacy.tokens import Span
from spacy import displacy
from spacy.pipeline import merge_entities
from spacy.training import Example
def test_issue5551():
@pytest.mark.parametrize(
"textcat_config", [default_model_config, bow_model_config, cnn_model_config]
)
def test_issue5551(textcat_config):
"""Test that after fixing the random seed, the results of the pipeline are truly identical"""
component = "textcat"
pipe_cfg = {
"model": {
"@architectures": "spacy.TextCatBOW.v1",
"exclusive_classes": True,
"ngram_size": 2,
"no_output_layer": False,
}
}
pipe_cfg = Config().from_str(textcat_config)
results = []
for i in range(3):
fix_random_seed(0)
nlp = English()
example = (
"Once hot, form ping-pong-ball-sized balls of the mixture, each weighing roughly 25 g.",
{"cats": {"Labe1": 1.0, "Label2": 0.0, "Label3": 0.0}},
)
text = "Once hot, form ping-pong-ball-sized balls of the mixture, each weighing roughly 25 g."
annots = {"cats": {"Labe1": 1.0, "Label2": 0.0, "Label3": 0.0}}
pipe = nlp.add_pipe(component, config=pipe_cfg, last=True)
for label in set(example[1]["cats"]):
for label in set(annots["cats"]):
pipe.add_label(label)
# Train
nlp.initialize()
doc = nlp.make_doc(text)
nlp.update([Example.from_dict(doc, annots)])
# Store the result of each iteration
result = pipe.model.predict([nlp.make_doc(example[0])])
result = pipe.model.predict([doc])
results.append(list(result[0]))
# All results should be the same because of the fixed seed
assert len(results) == 3

View File

@@ -72,6 +72,10 @@ def test_readers():
def test_cat_readers(reader, additional_config):
nlp_config_string = """
[training]
seed = 0
[training.score_weights]
cats_macro_auc = 1.0
[corpora]
@readers = "PLACEHOLDER"
@@ -92,9 +96,7 @@ def test_cat_readers(reader, additional_config):
config["corpora"]["@readers"] = reader
config["corpora"].update(additional_config)
nlp = load_model_from_config(config, auto_fill=True)
T = registry.resolve(
nlp.config["training"].interpolate(), schema=ConfigSchemaTraining
)
T = registry.resolve(nlp.config["training"], schema=ConfigSchemaTraining)
dot_names = [T["train_corpus"], T["dev_corpus"]]
train_corpus, dev_corpus = resolve_dot_names(nlp.config, dot_names)
optimizer = T["optimizer"]