Mirror of https://github.com/explosion/spaCy.git (synced 2025-11-04 09:57:26 +03:00)
Add test for custom data augmentation

commit c41a4332e4
parent 3856048437
@@ -7,11 +7,11 @@ from spacy.training.converters import json_to_docs
 from spacy.training.augment import create_orth_variants_augmenter
 from spacy.lang.en import English
 from spacy.tokens import Doc, DocBin
 from spacy.lookups import Lookups
 from spacy.util import get_words_and_spaces, minibatch
 from thinc.api import compounding
 import pytest
 import srsly
 import random

 from ..util import make_tempdir
@@ -515,6 +515,39 @@ def test_make_orth_variants(doc):
         list(reader(nlp))


+@pytest.mark.filterwarnings("ignore::UserWarning")
+def test_custom_data_augmentation(doc):
+    def create_spongebob_augmenter(randomize: bool = False):
+        def augment(nlp, example):
+            text = example.text
+            if randomize:
+                ch = [c.lower() if random.random() < 0.5 else c.upper() for c in text]
+            else:
+                ch = [c.lower() if i % 2 else c.upper() for i, c in enumerate(text)]
+            example_dict = example.to_dict()
+            doc = nlp.make_doc("".join(ch))
+            example_dict["token_annotation"]["ORTH"] = [t.text for t in doc]
+            yield example
+            yield example.from_dict(doc, example_dict)
+
+        return augment
+
+    nlp = English()
+    with make_tempdir() as tmpdir:
+        output_file = tmpdir / "roundtrip.spacy"
+        DocBin(docs=[doc]).to_disk(output_file)
+        reader = Corpus(output_file, augmenter=create_spongebob_augmenter())
+        corpus = list(reader(nlp))
+    orig_text = "Sarah 's sister flew to Silicon Valley via London . "
+    augmented = "SaRaH 's sIsTeR FlEw tO SiLiCoN VaLlEy vIa lOnDoN . "
+    assert corpus[0].text == orig_text
+    assert corpus[0].reference.text == orig_text
+    assert corpus[0].predicted.text == orig_text
+    assert corpus[1].text == augmented
+    assert corpus[1].reference.text == augmented
+    assert corpus[1].predicted.text == augmented
+
+
 @pytest.mark.skip("Outdated")
 @pytest.mark.parametrize(
     "tokens_a,tokens_b,expected",
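For context, the augmenter exercised by this test follows spaCy v3's data-augmentation callback shape: a factory returns a function that takes (nlp, example) and yields one or more Example objects, and Corpus applies it while reading training data. The augmenter yields the original example first and the augmented variant second, which is why the test checks corpus[0] against orig_text and corpus[1] against augmented. Beyond passing the factory to Corpus directly as the test does, it would typically be registered so a training config can reference it. A minimal sketch, assuming spaCy v3's registry.augmenters registry; the registry name "spongebob_augmenter.v1" is a hypothetical label chosen for this example:

import random

import spacy


# Hypothetical registry name; spaCy v3 resolves @augmenters entries
# in config files through this registry.
@spacy.registry.augmenters("spongebob_augmenter.v1")
def create_spongebob_augmenter(randomize: bool = False):
    def augment(nlp, example):
        text = example.text
        if randomize:
            # Randomize the case of each character independently.
            ch = [c.lower() if random.random() < 0.5 else c.upper() for c in text]
        else:
            # Alternate upper/lower case deterministically.
            ch = [c.lower() if i % 2 else c.upper() for i, c in enumerate(text)]
        example_dict = example.to_dict()
        doc = nlp.make_doc("".join(ch))
        example_dict["token_annotation"]["ORTH"] = [t.text for t in doc]
        # Keep the original example and add the augmented variant, so
        # augmentation extends rather than replaces the training data.
        yield example
        yield example.from_dict(doc, example_dict)

    return augment

With that registration in place, a training config could then request the augmenter for the training corpus:

[corpora.train.augmenter]
@augmenters = "spongebob_augmenter.v1"
randomize = false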