Improve unit tests readability

This commit is contained in:
thomashacker 2023-01-05 11:58:33 +01:00
parent 3279e7051a
commit 97a9c03398

View File

@ -202,10 +202,10 @@ depth = 4
window_size = 1 window_size = 1
""" """
TEXTCAT_LISTENER_CONFIG = """ NER_LISTENER_CONFIG = """
[nlp] [nlp]
lang = "en" lang = "en"
pipeline = ["tok2vec","textcat"] pipeline = ["tok2vec","ner"]
batch_size = 1000 batch_size = 1000
[components] [components]
@ -221,86 +221,32 @@ factory = "tok2vec"
width = ${components.tok2vec.model.encode.width} width = ${components.tok2vec.model.encode.width}
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"] attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
rows = [5000, 1000, 2500, 2500] rows = [5000, 1000, 2500, 2500]
include_static_vectors = true include_static_vectors = false
[components.tok2vec.model.encode] [components.tok2vec.model.encode]
@architectures = "spacy.MaxoutWindowEncoder.v2" @architectures = "spacy.MaxoutWindowEncoder.v2"
width = 256 width = 96
depth = 8 depth = 4
window_size = 1 window_size = 1
maxout_pieces = 3 maxout_pieces = 3
[components.textcat] [components.ner]
factory = "textcat" factory = "ner"
[components.textcat.model] [components.ner.model]
@architectures = "spacy.TextCatEnsemble.v2" @architectures = "spacy.TransitionBasedParser.v2"
state_type = "ner"
extra_state_tokens = false
hidden_width = 64
maxout_pieces = 2
use_upper = true
nO = null nO = null
[components.textcat.model.tok2vec] [components.ner.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1" @architectures = "spacy.Tok2VecListener.v1"
width = ${components.tok2vec.model.encode.width} width = ${components.tok2vec.model.encode.width}
[components.textcat.model.linear_model]
@architectures = "spacy.TextCatBOW.v2"
exclusive_classes = true
ngram_size = 1
no_output_layer = false
""" """
TEXTCAT_CONFIG = """
[nlp]
lang = "en"
pipeline = ["textcat"]
disabled = []
before_creation = null
after_creation = null
after_pipeline_creation = null
batch_size = 1000
tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"}
[components]
[components.textcat]
factory = "textcat"
threshold = 0.5
[components.textcat.model]
@architectures = "spacy.TextCatEnsemble.v2"
nO = null
[components.textcat.model.linear_model]
@architectures = "spacy.TextCatBOW.v1"
exclusive_classes = true
ngram_size = 1
no_output_layer = false
nO = null
[components.textcat.model.tok2vec]
@architectures = "spacy.Tok2Vec.v2"
[components.textcat.model.tok2vec.embed]
@architectures = "spacy.MultiHashEmbed.v1"
width = 64
rows = [2000,2000,1000,1000,1000,1000]
attrs = ["ORTH","LOWER","PREFIX","SUFFIX","SHAPE","ID"]
include_static_vectors = false
[components.textcat.model.tok2vec.encode]
@architectures = "spacy.MaxoutWindowEncoder.v2"
width = 64
window_size = 1
maxout_pieces = 3
depth = 2
"""
TEXTCAT_EXAMPLE_TEXTS = [
("This is a sentence for LABEL_A.", {"cats": {"LABEL_A": 1, "LABEL_B": 0}}),
("A sentence for the label LABEL_B.", {"cats": {"LABEL_A": 0, "LABEL_B": 1}}),
]
TEXTCAT_LABELS = ["LABEL_A", "LABEL_B"]
def _add_ner_label(ner, data): def _add_ner_label(ner, data):
for _, annotations in data: for _, annotations in data:
@ -368,57 +314,15 @@ def test_rehearse(component):
def test_rehearse_textcat_multilabel_listener(): def test_rehearse_textcat_multilabel_listener():
"""Test nlp.rehearse on a textcat_multilabel pipeline with a tok2vec listener""" """Test nlp.rehearse on a textcat_multilabel pipeline with a tok2vec listener"""
config = Config().from_str(TEXTCAT_MULTILABEL_LISTENER_CONFIG) config = Config().from_str(TEXTCAT_MULTILABEL_LISTENER_CONFIG)
nlp = spacy.blank("en").from_config(config) nlp = spacy.blank("en", config=config)
textcat_multilabel = nlp.get_pipe("textcat_multilabel") nlp = _optimize(nlp, "textcat_multilabel", TRAIN_DATA, False)
for label in TEXTCAT_LABELS: _optimize(nlp, "textcat_multilabel", REHEARSE_DATA, True)
textcat_multilabel.add_label(label)
nlp.initialize()
examples = []
for example in TEXTCAT_EXAMPLE_TEXTS:
example = Example.from_dict(nlp.make_doc(example[0]), example[1])
examples.append(example)
nlp.update([example])
optimizer = nlp.resume_training()
nlp.rehearse(examples, sgd=optimizer)
@pytest.mark.issue(12044) @pytest.mark.issue(12044)
def test_rehearse_textcat_listener(): def test_rehearse_ner_listener():
"""Test nlp.rehearse on a textcat pipeline with a tok2vec listener""" """Test nlp.rehearse on a ner pipeline with a tok2vec listener"""
config = Config().from_str(TEXTCAT_LISTENER_CONFIG) config = Config().from_str(NER_LISTENER_CONFIG)
nlp = spacy.blank("en").from_config(config) nlp = spacy.blank("en", config=config)
textcat = nlp.get_pipe("textcat") nlp = _optimize(nlp, "ner", TRAIN_DATA, False)
for label in TEXTCAT_LABELS: _optimize(nlp, "ner", REHEARSE_DATA, True)
textcat.add_label(label)
nlp.initialize()
examples = []
for example in TEXTCAT_EXAMPLE_TEXTS:
example = Example.from_dict(nlp.make_doc(example[0]), example[1])
examples.append(example)
nlp.update([example])
optimizer = nlp.resume_training()
nlp.rehearse(examples, sgd=optimizer)
@pytest.mark.issue(12044)
def test_rehearse_textcat():
"""Test nlp.rehearse on a textcat pipeline with an inline tok2vec component"""
config = Config().from_str(TEXTCAT_CONFIG)
nlp = spacy.blank("en").from_config(config)
textcat = nlp.get_pipe("textcat")
for label in TEXTCAT_LABELS:
textcat.add_label(label)
nlp.initialize()
examples = []
for example in TEXTCAT_EXAMPLE_TEXTS:
example = Example.from_dict(nlp.make_doc(example[0]), example[1])
examples.append(example)
nlp.update([example])
optimizer = nlp.resume_training()
nlp.rehearse(examples, sgd=optimizer)