# Mirror of https://github.com/explosion/spaCy.git
# (synced 2025-11-04 01:48:04 +03:00)
#
# Last commit:
#   * Use isort with Black profile
#   * isort all the things
#   * Fix import cycles as a result of import sorting
#   * Add DOCBIN_ALL_ATTRS type definition
#   * Add isort to requirements
#   * Remove isort from build dependencies check
#   * Typo
#
# File: Python, 246 lines, 10 KiB
import random
 | 
						||
from contextlib import contextmanager
 | 
						||
 | 
						||
import pytest
 | 
						||
 | 
						||
from spacy.lang.en import English
 | 
						||
from spacy.pipeline._parser_internals.nonproj import contains_cycle
 | 
						||
from spacy.tokens import Doc, DocBin, Span
 | 
						||
from spacy.training import Corpus, Example
 | 
						||
from spacy.training.augment import (
 | 
						||
    create_lower_casing_augmenter,
 | 
						||
    create_orth_variants_augmenter,
 | 
						||
    make_whitespace_variant,
 | 
						||
)
 | 
						||
 | 
						||
from ..util import make_tempdir
 | 
						||
 | 
						||
 | 
						||
@contextmanager
def make_docbin(docs, name="roundtrip.spacy"):
    """Serialize ``docs`` into a ``DocBin`` file and yield the file's path.

    The file is written inside a ``make_tempdir()`` context, so the path is
    only meant to be used while this context manager is active.

    docs: the ``Doc`` objects to serialize.
    name: file name to use inside the temporary directory.
    """
    with make_tempdir() as temp_dir:
        docbin_path = temp_dir / name
        doc_bin = DocBin(docs=docs)
        doc_bin.to_disk(docbin_path)
        yield docbin_path
 | 
						||
 | 
						||
 | 
						||
@pytest.fixture
def nlp():
    """Provide a fresh ``English()`` language object for each test."""
    english = English()
    return english
 | 
						||
 | 
						||
 | 
						||
@pytest.fixture
def doc(nlp):
    """Build an annotated ``Doc`` with tags, POS, entities and text categories.

    The entity tag for "flew" is the empty string (missing annotation)
    rather than "O".
    """
    # fmt: off
    annotations = {
        "words": ["Sarah", "'s", "sister", "flew", "to", "Silicon", "Valley", "via", "London", "."],
        "tags": ["NNP", "POS", "NN", "VBD", "IN", "NNP", "NNP", "IN", "NNP", "."],
        "pos": ["PROPN", "PART", "NOUN", "VERB", "ADP", "PROPN", "PROPN", "ADP", "PROPN", "PUNCT"],
        "ents": ["B-PERSON", "I-PERSON", "O", "", "O", "B-LOC", "I-LOC", "O", "B-GPE", "O"],
    }
    # fmt: on
    annotated = Doc(nlp.vocab, **annotations)
    annotated.cats = {"TRAVEL": 1.0, "BAKING": 0.0}
    return annotated
 | 
						||
 | 
						||
 | 
						||
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_make_orth_variants(nlp):
    """Exercise the orth-variants augmenter in four configurations.

    1. A randomized run (level=0.2, lower=0.5) only has to complete.
    2. level=1.0, lower=1.0 on a tagged doc must lowercase every token.
    3. The same settings on an untagged doc must still lowercase.
    4. lower=0.0 must leave token casing untouched.
    """
    single_variants = [
        {"tags": ["NFP"], "variants": ["…", "..."]},
        {"tags": [":"], "variants": ["-", "—", "–", "--", "---", "——"]},
    ]
    # fmt: off
    words = ["\n\n", "A", "\t", "B", "a", "b", "…", "...", "-", "—", "–", "--", "---", "——"]
    tags = ["_SP", "NN", "\t", "NN", "NN", "NN", "NFP", "NFP", ":", ":", ":", ":", ":", ":"]
    # fmt: on
    spaces = [True] * len(words)
    # the two leading whitespace tokens are not followed by a space
    spaces[0] = spaces[2] = False
    tagged_doc = Doc(nlp.vocab, words=words, spaces=spaces, tags=tags)

    def build_augmenter(level, lower):
        # a fresh orth_variants dict per augmenter, as in separate calls
        return create_orth_variants_augmenter(
            level=level, lower=lower, orth_variants={"single": single_variants}
        )

    # 1. randomized settings: results vary, so only check that it runs
    augmenter = build_augmenter(0.2, 0.5)
    with make_docbin([tagged_doc] * 10) as path:
        corpus = Corpus(path, augmenter=augmenter)
        list(corpus(nlp))

    # 2. these settings must lowercase everything
    augmenter = build_augmenter(1.0, 1.0)
    with make_docbin([tagged_doc] * 10) as path:
        corpus = Corpus(path, augmenter=augmenter)
        for eg in corpus(nlp):
            for tok in eg.reference:
                assert tok.text == tok.text.lower()

    # 3. lowercasing must also be applied when there are no tags
    untagged_doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words))
    augmenter = build_augmenter(1.0, 1.0)
    with make_docbin([untagged_doc] * 10) as path:
        corpus = Corpus(path, augmenter=augmenter)
        for eg in corpus(nlp):
            for augmented_tok, source_tok in zip(eg.reference, untagged_doc):
                assert augmented_tok.text == source_tok.text.lower()

    # 4. lower=0.0 must not lowercase anything
    untagged_doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words))
    augmenter = build_augmenter(1.0, 0.0)
    with make_docbin([untagged_doc] * 10) as path:
        corpus = Corpus(path, augmenter=augmenter)
        for eg in corpus(nlp):
            for augmented_tok, source_tok in zip(eg.reference, untagged_doc):
                assert augmented_tok.text == source_tok.text
 | 
						||
 | 
						||
 | 
						||
def test_lowercase_augmenter(nlp, doc):
    """Lowercasing keeps annotations aligned, even when predicted tokens differ.

    First, the annotated ``doc`` fixture is lowercased: text, entity offsets,
    entity IOB values and POS must line up with the original. Second, a doc
    whose lowercased text may tokenize differently is checked: the reference
    keeps the original (lowercased) words, while the predicted doc must match
    what ``nlp.make_doc`` produces for the lowercased text.
    """
    augmenter = create_lower_casing_augmenter(level=1.0)
    with make_docbin([doc]) as path:
        examples = list(Corpus(path, augmenter=augmenter)(nlp))
    first = examples[0]
    lowered = doc.text.lower()
    assert first.reference.text == lowered
    assert first.predicted.text == lowered
    expected_ents = [(ent.start, ent.end, ent.label) for ent in doc.ents]
    actual_ents = [(ent.start, ent.end, ent.label) for ent in first.reference.ents]
    assert actual_ents == expected_ents
    for augmented_ent, source_ent in zip(first.reference.ents, doc.ents):
        assert augmented_ent.text == source_ent.text.lower()
    assert [tok.ent_iob for tok in doc] == [tok.ent_iob for tok in first.reference]
    assert [tok.pos_ for tok in first.reference] == [tok.pos_ for tok in doc]

    # augmentation must still work when lowercasing leads to a different
    # predicted tokenization than the reference one
    words = ["A", "B", "CCC."]
    short_doc = Doc(nlp.vocab, words=words)
    with make_docbin([short_doc]) as path:
        examples = list(Corpus(path, augmenter=augmenter)(nlp))
    first = examples[0]
    lowered = short_doc.text.lower()
    assert first.reference.text == lowered
    assert first.predicted.text == lowered
    assert [tok.text for tok in first.reference] == [word.lower() for word in words]
    predicted_texts = [tok.text for tok in first.predicted]
    assert predicted_texts == [tok.text for tok in nlp.make_doc(lowered)]
 | 
						||
 | 
						||
 | 
						||
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_custom_data_augmentation(nlp, doc):
    """A user-defined augmenter may yield the original example plus a variant.

    The augmenter below alternates character case ("spongebob" case) and
    yields the unmodified example first, then the re-cased one. Both must
    keep the original entity offsets.
    """

    def create_spongebob_augmenter(randomize: bool = False):
        def augment(nlp, example):
            text = example.text

            def recase(index, char):
                # randomize: coin flip per character; otherwise odd
                # positions are lowercased and even positions uppercased
                if randomize:
                    make_lower = random.random() < 0.5
                else:
                    make_lower = bool(index % 2)
                return char.lower() if make_lower else char.upper()

            recased = "".join(recase(i, c) for i, c in enumerate(text))
            example_dict = example.to_dict()
            recased_doc = nlp.make_doc(recased)
            example_dict["token_annotation"]["ORTH"] = [tok.text for tok in recased_doc]
            yield example
            yield example.from_dict(recased_doc, example_dict)

        return augment

    with make_docbin([doc]) as path:
        corpus = Corpus(path, augmenter=create_spongebob_augmenter())
        examples = list(corpus(nlp))
    original_text = "Sarah 's sister flew to Silicon Valley via London . "
    recased_text = "SaRaH 's sIsTeR FlEw tO SiLiCoN VaLlEy vIa lOnDoN . "
    for eg, expected in ((examples[0], original_text), (examples[1], recased_text)):
        assert eg.text == expected
        assert eg.reference.text == expected
        assert eg.predicted.text == expected
    expected_ents = [(ent.start, ent.end, ent.label) for ent in doc.ents]
    for eg in (examples[0], examples[1]):
        assert [(ent.start, ent.end, ent.label) for ent in eg.reference.ents] == expected_ents
 | 
						||
 | 
						||
 | 
						||
def test_make_whitespace_variant(nlp):
    """Check that ``make_whitespace_variant`` inserts well-formed whitespace.

    The reference doc has tags, lemmas, dependencies and entities (one token
    with missing entity information) but no POS or MORPH annotation.
    Inserting a whitespace token must:

    * extend the tag, lemma and dependency layers that are present,
    * leave the absent POS/MORPH layers absent,
    * keep the dependency tree acyclic with exactly two sentences,
    * never put whitespace at the start or end of an entity.

    The function must return an unmodified text when the reference has
    partial dependencies, span groups, or entity links (``kb_id``).
    """
    # fmt: off
    text = "They flew to New York City.\nThen they drove to Washington, D.C."
    words = ["They", "flew", "to", "New", "York", "City", ".", "\n", "Then", "they", "drove", "to", "Washington", ",", "D.C."]
    spaces = [True, True, True, True, True, False, False, False, True, True, True, True, False, True, False]
    tags = ["PRP", "VBD", "IN", "NNP", "NNP", "NNP", ".", "_SP", "RB", "PRP", "VBD", "IN", "NNP", ",", "NNP"]
    lemmas = ["they", "fly", "to", "New", "York", "City", ".", "\n", "then", "they", "drive", "to", "Washington", ",", "D.C."]
    heads = [1, 1, 1, 4, 5, 2, 1, 10, 10, 10, 10, 10, 11, 12, 12]
    deps = ["nsubj", "ROOT", "prep", "compound", "compound", "pobj", "punct", "dep", "advmod", "nsubj", "ROOT", "prep", "pobj", "punct", "appos"]
    ents = ["O", "", "O", "B-GPE", "I-GPE", "I-GPE", "O", "O", "O", "O", "O", "O", "B-GPE", "O", "B-GPE"]
    # fmt: on
    doc = Doc(
        nlp.vocab,
        words=words,
        spaces=spaces,
        tags=tags,
        lemmas=lemmas,
        heads=heads,
        deps=deps,
        ents=ents,
    )
    assert doc.text == text
    example = Example(nlp.make_doc(text), doc)

    # whitespace is only added internally in entity spans: positions 3 and 6
    # are the outer boundaries of "New York City", so its text is unchanged
    mod_ex = make_whitespace_variant(nlp, example, " ", 3)
    assert mod_ex.reference.ents[0].text == "New York City"
    mod_ex = make_whitespace_variant(nlp, example, " ", 4)
    assert mod_ex.reference.ents[0].text == "New  York City"
    mod_ex = make_whitespace_variant(nlp, example, " ", 5)
    assert mod_ex.reference.ents[0].text == "New York  City"
    mod_ex = make_whitespace_variant(nlp, example, " ", 6)
    assert mod_ex.reference.ents[0].text == "New York City"

    # add a space at every possible position
    for i in range(len(doc) + 1):
        mod_ex = make_whitespace_variant(nlp, example, " ", i)
        mod_doc = mod_ex.reference
        assert mod_doc[i].is_space
        # adds annotation when the doc contains at least partial annotation
        assert [t.tag_ for t in mod_doc] == tags[:i] + ["_SP"] + tags[i:]
        assert [t.lemma_ for t in mod_doc] == lemmas[:i] + [" "] + lemmas[i:]
        assert [t.dep_ for t in mod_doc] == deps[:i] + ["dep"] + deps[i:]
        # does not add partial annotation if doc does not contain this feature
        assert not mod_doc.has_annotation("POS")
        assert not mod_doc.has_annotation("MORPH")
        # produces well-formed trees; check the *modified* doc, since the
        # input doc is never changed by make_whitespace_variant
        assert not contains_cycle([t.head.i for t in mod_doc])
        assert len(list(mod_doc.sents)) == 2
        # the new token attaches to its left neighbour, or to token 1 when
        # it is inserted at the very start
        expected_head = 1 if i == 0 else i - 1
        assert mod_doc[i].head.i == expected_head
        # adding another space also produces well-formed trees
        for j in (3, 8, 10):
            mod_ex2 = make_whitespace_variant(nlp, mod_ex, "\t\t\n", j)
            mod_doc2 = mod_ex2.reference
            assert not contains_cycle([t.head.i for t in mod_doc2])
            assert len(list(mod_doc2.sents)) == 2
            assert mod_doc2[j].head.i == j - 1
        # entities are well-formed
        assert len(doc.ents) == len(mod_doc.ents)
        # there is one token with missing entity information
        assert any(t.ent_iob == 0 for t in mod_doc)
        for ent in mod_doc.ents:
            assert not ent[0].is_space
            assert not ent[-1].is_space

    # no modifications if:
    # partial dependencies
    example.reference[0].dep_ = ""
    mod_ex = make_whitespace_variant(nlp, example, " ", 5)
    assert mod_ex.text == example.reference.text
    example.reference[0].dep_ = "nsubj"  # reset

    # spans
    example.reference.spans["spans"] = [example.reference[0:5]]
    mod_ex = make_whitespace_variant(nlp, example, " ", 5)
    assert mod_ex.text == example.reference.text
    del example.reference.spans["spans"]  # reset

    # links: build the span on the same doc whose ents are being set
    link_span = Span(example.reference, 0, 2, label="ENT", kb_id="Q123")
    example.reference.ents = [link_span]
    mod_ex = make_whitespace_variant(nlp, example, " ", 5)
    assert mod_ex.text == example.reference.text
 |