spaCy/spacy/tests/training/test_readers.py

120 lines
3.8 KiB
Python
Raw Normal View History

from typing import Dict, Iterable, Callable
import pytest
from thinc.api import Config
from spacy import Language
2020-09-28 16:09:59 +03:00
from spacy.util import load_model_from_config, registry, resolve_dot_names
from spacy.schemas import ConfigSchemaTraining
from spacy.training import Example
def test_readers():
    """Check that a custom corpus reader registered under ``[corpora]`` works.

    ``myreader.v1`` returns a dict of corpus callables.
    - The ``train_corpus`` / ``dev_corpus`` dot names of the auto-filled
      ``[training]`` section are resolved against that dict.
    - Those corpora are used to initialize, update and evaluate a small
      tok2vec + textcat pipeline.
    - Finally, the additional ``extra`` entry must also be resolvable from the
      interpolated ``[corpora]`` block.
    """
    config_string = """
    [training]

    [corpora]
    @readers = "myreader.v1"

    [nlp]
    lang = "en"
    pipeline = ["tok2vec", "textcat"]

    [components]

    [components.tok2vec]
    factory = "tok2vec"

    [components.textcat]
    factory = "textcat"
    """

    @registry.readers.register("myreader.v1")
    def myreader() -> Dict[str, Callable[[Language], Iterable[Example]]]:
        annots = {"cats": {"POS": 1.0, "NEG": 0.0}}

        # Each corpus callable receives only the nlp object (it is invoked as
        # ``train_corpus(nlp)`` below), hence ``Callable[[Language], ...]``.
        def reader(nlp: Language):
            doc = nlp.make_doc("This is an example")
            return [Example.from_dict(doc, annots)]

        return {"train": reader, "dev": reader, "extra": reader, "something": reader}

    config = Config().from_str(config_string)
    nlp = load_model_from_config(config, auto_fill=True)
    T = registry.resolve(
        nlp.config.interpolate()["training"], schema=ConfigSchemaTraining
    )
    # [training] refers to the corpora by dot name; resolve them to callables.
    dot_names = [T["train_corpus"], T["dev_corpus"]]
    train_corpus, dev_corpus = resolve_dot_names(nlp.config, dot_names)
    assert callable(train_corpus)
    optimizer = T["optimizer"]
    # simulate a training loop
    nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
    for example in train_corpus(nlp):
        nlp.update([example], sgd=optimizer)
    scores = nlp.evaluate(list(dev_corpus(nlp)))
    assert scores["cats_score"] == 0.0
    # ensure the pipeline runs
    doc = nlp("Quick test")
    assert doc.cats
    # Entries beyond train/dev must be reachable from [corpora] as well.
    corpora = {"corpora": nlp.config.interpolate()["corpora"]}
    extra_corpus = registry.resolve(corpora)["corpora"]["extra"]
    assert callable(extra_corpus)
@pytest.mark.slow
@pytest.mark.parametrize(
    "reader,additional_config",
    [
        ("ml_datasets.imdb_sentiment.v1", {"train_limit": 10, "dev_limit": 2}),
        ("ml_datasets.dbpedia.v1", {"train_limit": 10, "dev_limit": 2}),
        ("ml_datasets.cmu_movies.v1", {"limit": 10, "freq_cutoff": 200, "split": 0.8}),
    ],
)
def test_cat_readers(reader, additional_config):
    """Train and evaluate a textcat pipeline on an ``ml_datasets`` corpus reader.

    ``reader`` is written into ``[corpora]`` as the ``@readers`` name, and
    ``additional_config`` supplies that reader's settings (limits, split).
    - The resolved train corpus drives a short update loop.
    - The dev corpus is checked and then evaluated.
    - Every example must carry both positive and negative category labels.
    """
    # Local import keeps the file-level import block untouched.
    from spacy.util import fix_random_seed

    nlp_config_string = """
    [training]
    seed = 0

    [training.score_weights]
    cats_macro_auc = 1.0

    [corpora]
    @readers = "PLACEHOLDER"

    [nlp]
    lang = "en"
    pipeline = ["tok2vec", "textcat"]

    [components]

    [components.tok2vec]
    factory = "tok2vec"

    [components.textcat]
    factory = "textcat"
    """
    config = Config().from_str(nlp_config_string)
    # Apply the declared seed up front so model initialization and the updates
    # below are reproducible; otherwise `assert scores["cats_score"]` can flake.
    fix_random_seed(config["training"]["seed"])
    config["corpora"]["@readers"] = reader
    config["corpora"].update(additional_config)
    nlp = load_model_from_config(config, auto_fill=True)
    # Interpolate before resolving, as test_readers does: a no-op without
    # ${...} variables, and required if the auto-filled config contains any.
    T = registry.resolve(
        nlp.config.interpolate()["training"], schema=ConfigSchemaTraining
    )
    dot_names = [T["train_corpus"], T["dev_corpus"]]
    train_corpus, dev_corpus = resolve_dot_names(nlp.config, dot_names)
    optimizer = T["optimizer"]
    # simulate a training loop
    nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
    for example in train_corpus(nlp):
        assert example.y.cats
        # each training example needs at least one positive (1.0) and one
        # negative (0.0) label, with no other score values
        assert sorted(set(example.y.cats.values())) == [0.0, 1.0]
        nlp.update([example], sgd=optimizer)
    # simulate performance benchmark on dev corpus
    dev_examples = list(dev_corpus(nlp))
    for example in dev_examples:
        # same requirement for dev: both a 1.0 and a 0.0 label, nothing else
        assert sorted(set(example.y.cats.values())) == [0.0, 1.0]
    scores = nlp.evaluate(dev_examples)
    assert scores["cats_score"]
    # ensure the pipeline runs
    doc = nlp("Quick test")
    assert doc.cats