spaCy/spacy/tests/regression/test_issue6908.py
René Octavio Queiroz Dias 999ff03b19
fix: Fix textcat labels to expect a Optional[Iterable[str]] instead of Optional[Dict] (#6911)
* docs: Add agreement

* bug: Regression test

Issue #6908

* fix: Changed from Dict to Iterable[str]

Fix #6908

* Update test to use make_tempdir

* fix: Fix WindowsPath error

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
2021-02-04 23:37:13 +01:00

103 lines
2.1 KiB
Python

import pytest
import spacy
from spacy.language import Language
from spacy.tokens import DocBin
from spacy import util
from spacy.schemas import ConfigSchemaInit
from spacy.training.initialize import init_nlp
from ..util import make_tempdir
TEXTCAT_WITH_LABELS_ARRAY_CONFIG = """
[paths]
train = "TRAIN_PLACEHOLDER"
raw = null
init_tok2vec = null
vectors = null
[system]
seed = 0
gpu_allocator = null
[nlp]
lang = "en"
pipeline = ["textcat"]
tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"}
disabled = []
before_creation = null
after_creation = null
after_pipeline_creation = null
batch_size = 1000
[components]
[components.textcat]
factory = "TEXTCAT_PLACEHOLDER"
[corpora]
[corpora.train]
@readers = "spacy.Corpus.v1"
path = ${paths:train}
[corpora.dev]
@readers = "spacy.Corpus.v1"
path = ${paths:train}
[training]
train_corpus = "corpora.train"
dev_corpus = "corpora.dev"
seed = ${system.seed}
gpu_allocator = ${system.gpu_allocator}
frozen_components = []
before_to_disk = null
[pretraining]
[initialize]
vectors = ${paths.vectors}
init_tok2vec = ${paths.init_tok2vec}
vocab_data = null
lookups = null
before_init = null
after_init = null
[initialize.components]
[initialize.components.textcat]
labels = ['label1', 'label2']
[initialize.tokenizer]
"""
@pytest.mark.parametrize(
"component_name",
["textcat", "textcat_multilabel"],
)
def test_textcat_initialize_labels_validation(component_name):
"""Test intializing textcat with labels in a list"""
def create_data(out_file):
nlp = spacy.blank("en")
doc = nlp.make_doc("Some text")
doc.cats = {"label1": 0, "label2": 1}
out_data = DocBin(docs=[doc]).to_bytes()
with out_file.open("wb") as file_:
file_.write(out_data)
with make_tempdir() as tmp_path:
train_path = tmp_path / "train.spacy"
create_data(train_path)
config_str = TEXTCAT_WITH_LABELS_ARRAY_CONFIG.replace(
"TEXTCAT_PLACEHOLDER", component_name
)
config_str = config_str.replace("TRAIN_PLACEHOLDER", train_path.as_posix())
config = util.load_config_from_str(config_str)
init_nlp(config)