mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
Fixes in test suite (#6457)
* fix slow test for textcat readers * cleanup test_issue5551 * add explicit score weight * cleanup
This commit is contained in:
parent
31ec9a906e
commit
d6c616a125
|
@ -1,35 +1,38 @@
|
|||
from thinc.api import fix_random_seed
|
||||
import pytest
|
||||
from thinc.api import Config, fix_random_seed
|
||||
|
||||
from spacy.lang.en import English
|
||||
from spacy.pipeline.textcat import default_model_config, bow_model_config
|
||||
from spacy.pipeline.textcat import cnn_model_config
|
||||
from spacy.tokens import Span
|
||||
from spacy import displacy
|
||||
from spacy.pipeline import merge_entities
|
||||
from spacy.training import Example
|
||||
|
||||
|
||||
def test_issue5551():
|
||||
@pytest.mark.parametrize(
|
||||
"textcat_config", [default_model_config, bow_model_config, cnn_model_config]
|
||||
)
|
||||
def test_issue5551(textcat_config):
|
||||
"""Test that after fixing the random seed, the results of the pipeline are truly identical"""
|
||||
component = "textcat"
|
||||
pipe_cfg = {
|
||||
"model": {
|
||||
"@architectures": "spacy.TextCatBOW.v1",
|
||||
"exclusive_classes": True,
|
||||
"ngram_size": 2,
|
||||
"no_output_layer": False,
|
||||
}
|
||||
}
|
||||
|
||||
pipe_cfg = Config().from_str(textcat_config)
|
||||
results = []
|
||||
for i in range(3):
|
||||
fix_random_seed(0)
|
||||
nlp = English()
|
||||
example = (
|
||||
"Once hot, form ping-pong-ball-sized balls of the mixture, each weighing roughly 25 g.",
|
||||
{"cats": {"Labe1": 1.0, "Label2": 0.0, "Label3": 0.0}},
|
||||
)
|
||||
text = "Once hot, form ping-pong-ball-sized balls of the mixture, each weighing roughly 25 g."
|
||||
annots = {"cats": {"Labe1": 1.0, "Label2": 0.0, "Label3": 0.0}}
|
||||
pipe = nlp.add_pipe(component, config=pipe_cfg, last=True)
|
||||
for label in set(example[1]["cats"]):
|
||||
for label in set(annots["cats"]):
|
||||
pipe.add_label(label)
|
||||
# Train
|
||||
nlp.initialize()
|
||||
doc = nlp.make_doc(text)
|
||||
nlp.update([Example.from_dict(doc, annots)])
|
||||
# Store the result of each iteration
|
||||
result = pipe.model.predict([nlp.make_doc(example[0])])
|
||||
result = pipe.model.predict([doc])
|
||||
results.append(list(result[0]))
|
||||
# All results should be the same because of the fixed seed
|
||||
assert len(results) == 3
|
||||
|
|
|
@ -72,6 +72,10 @@ def test_readers():
|
|||
def test_cat_readers(reader, additional_config):
|
||||
nlp_config_string = """
|
||||
[training]
|
||||
seed = 0
|
||||
|
||||
[training.score_weights]
|
||||
cats_macro_auc = 1.0
|
||||
|
||||
[corpora]
|
||||
@readers = "PLACEHOLDER"
|
||||
|
@ -92,9 +96,7 @@ def test_cat_readers(reader, additional_config):
|
|||
config["corpora"]["@readers"] = reader
|
||||
config["corpora"].update(additional_config)
|
||||
nlp = load_model_from_config(config, auto_fill=True)
|
||||
T = registry.resolve(
|
||||
nlp.config["training"].interpolate(), schema=ConfigSchemaTraining
|
||||
)
|
||||
T = registry.resolve(nlp.config["training"], schema=ConfigSchemaTraining)
|
||||
dot_names = [T["train_corpus"], T["dev_corpus"]]
|
||||
train_corpus, dev_corpus = resolve_dot_names(nlp.config, dot_names)
|
||||
optimizer = T["optimizer"]
|
||||
|
|
Loading…
Reference in New Issue
Block a user