mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
Add test for custom data augmentation
This commit is contained in:
parent
3856048437
commit
c41a4332e4
|
@ -7,11 +7,11 @@ from spacy.training.converters import json_to_docs
|
||||||
from spacy.training.augment import create_orth_variants_augmenter
|
from spacy.training.augment import create_orth_variants_augmenter
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.tokens import Doc, DocBin
|
from spacy.tokens import Doc, DocBin
|
||||||
from spacy.lookups import Lookups
|
|
||||||
from spacy.util import get_words_and_spaces, minibatch
|
from spacy.util import get_words_and_spaces, minibatch
|
||||||
from thinc.api import compounding
|
from thinc.api import compounding
|
||||||
import pytest
|
import pytest
|
||||||
import srsly
|
import srsly
|
||||||
|
import random
|
||||||
|
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
|
|
||||||
|
@ -515,6 +515,39 @@ def test_make_orth_variants(doc):
|
||||||
list(reader(nlp))
|
list(reader(nlp))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.filterwarnings("ignore::UserWarning")
|
||||||
|
def test_custom_data_augmentation(doc):
|
||||||
|
def create_spongebob_augmenter(randomize: bool = False):
|
||||||
|
def augment(nlp, example):
|
||||||
|
text = example.text
|
||||||
|
if randomize:
|
||||||
|
ch = [c.lower() if random.random() < 0.5 else c.upper() for c in text]
|
||||||
|
else:
|
||||||
|
ch = [c.lower() if i % 2 else c.upper() for i, c in enumerate(text)]
|
||||||
|
example_dict = example.to_dict()
|
||||||
|
doc = nlp.make_doc("".join(ch))
|
||||||
|
example_dict["token_annotation"]["ORTH"] = [t.text for t in doc]
|
||||||
|
yield example
|
||||||
|
yield example.from_dict(doc, example_dict)
|
||||||
|
|
||||||
|
return augment
|
||||||
|
|
||||||
|
nlp = English()
|
||||||
|
with make_tempdir() as tmpdir:
|
||||||
|
output_file = tmpdir / "roundtrip.spacy"
|
||||||
|
DocBin(docs=[doc]).to_disk(output_file)
|
||||||
|
reader = Corpus(output_file, augmenter=create_spongebob_augmenter())
|
||||||
|
corpus = list(reader(nlp))
|
||||||
|
orig_text = "Sarah 's sister flew to Silicon Valley via London . "
|
||||||
|
augmented = "SaRaH 's sIsTeR FlEw tO SiLiCoN VaLlEy vIa lOnDoN . "
|
||||||
|
assert corpus[0].text == orig_text
|
||||||
|
assert corpus[0].reference.text == orig_text
|
||||||
|
assert corpus[0].predicted.text == orig_text
|
||||||
|
assert corpus[1].text == augmented
|
||||||
|
assert corpus[1].reference.text == augmented
|
||||||
|
assert corpus[1].predicted.text == augmented
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("Outdated")
|
@pytest.mark.skip("Outdated")
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"tokens_a,tokens_b,expected",
|
"tokens_a,tokens_b,expected",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user