Add test for custom data augmentation

This commit is contained in:
Ines Montani 2020-10-02 11:37:56 +02:00
parent 3856048437
commit c41a4332e4

View File

@ -7,11 +7,11 @@ from spacy.training.converters import json_to_docs
from spacy.training.augment import create_orth_variants_augmenter
from spacy.lang.en import English
from spacy.tokens import Doc, DocBin
from spacy.lookups import Lookups
from spacy.util import get_words_and_spaces, minibatch
from thinc.api import compounding
import pytest
import srsly
import random
from ..util import make_tempdir
@ -515,6 +515,39 @@ def test_make_orth_variants(doc):
list(reader(nlp))
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_custom_data_augmentation(doc):
def create_spongebob_augmenter(randomize: bool = False):
def augment(nlp, example):
text = example.text
if randomize:
ch = [c.lower() if random.random() < 0.5 else c.upper() for c in text]
else:
ch = [c.lower() if i % 2 else c.upper() for i, c in enumerate(text)]
example_dict = example.to_dict()
doc = nlp.make_doc("".join(ch))
example_dict["token_annotation"]["ORTH"] = [t.text for t in doc]
yield example
yield example.from_dict(doc, example_dict)
return augment
nlp = English()
with make_tempdir() as tmpdir:
output_file = tmpdir / "roundtrip.spacy"
DocBin(docs=[doc]).to_disk(output_file)
reader = Corpus(output_file, augmenter=create_spongebob_augmenter())
corpus = list(reader(nlp))
orig_text = "Sarah 's sister flew to Silicon Valley via London . "
augmented = "SaRaH 's sIsTeR FlEw tO SiLiCoN VaLlEy vIa lOnDoN . "
assert corpus[0].text == orig_text
assert corpus[0].reference.text == orig_text
assert corpus[0].predicted.text == orig_text
assert corpus[1].text == augmented
assert corpus[1].reference.text == augmented
assert corpus[1].predicted.text == augmented
@pytest.mark.skip("Outdated")
@pytest.mark.parametrize(
"tokens_a,tokens_b,expected",