Update data augmenters (#6196)

* Draft lower-case augmenter

* Make warning a debug log

* Update lowercase augmenter, docs and tests

Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
Ines Montani 2020-10-04 17:46:29 +02:00 committed by GitHub
parent d38dc466c5
commit 3c36a57e84
5 changed files with 161 additions and 65 deletions

View File

@ -0,0 +1,100 @@
import pytest
from spacy.training import Corpus
from spacy.training.augment import create_orth_variants_augmenter
from spacy.training.augment import create_lower_casing_augmenter
from spacy.lang.en import English
from spacy.tokens import DocBin, Doc
from contextlib import contextmanager
import random
from ..util import make_tempdir


@contextmanager
def make_docbin(docs, name="roundtrip.spacy"):
    with make_tempdir() as tmpdir:
        output_file = tmpdir / name
        DocBin(docs=docs).to_disk(output_file)
        yield output_file


@pytest.fixture
def nlp():
    return English()


@pytest.fixture
def doc(nlp):
    # fmt: off
    words = ["Sarah", "'s", "sister", "flew", "to", "Silicon", "Valley", "via", "London", "."]
    tags = ["NNP", "POS", "NN", "VBD", "IN", "NNP", "NNP", "IN", "NNP", "."]
    pos = ["PROPN", "PART", "NOUN", "VERB", "ADP", "PROPN", "PROPN", "ADP", "PROPN", "PUNCT"]
    ents = ["B-PERSON", "I-PERSON", "O", "O", "O", "B-LOC", "I-LOC", "O", "B-GPE", "O"]
    cats = {"TRAVEL": 1.0, "BAKING": 0.0}
    # fmt: on
    doc = Doc(nlp.vocab, words=words, tags=tags, pos=pos, ents=ents)
    doc.cats = cats
    return doc


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_make_orth_variants(nlp, doc):
    single = [
        {"tags": ["NFP"], "variants": ["…", "..."]},
        {"tags": [":"], "variants": ["-", "—", "–", "--", "---", "——"]},
    ]
    augmenter = create_orth_variants_augmenter(
        level=0.2, lower=0.5, orth_variants={"single": single}
    )
    with make_docbin([doc]) as output_file:
        reader = Corpus(output_file, augmenter=augmenter)
        # Due to randomness, only test that it works without errors for now
        list(reader(nlp))


def test_lowercase_augmenter(nlp, doc):
    augmenter = create_lower_casing_augmenter(level=1.0)
    with make_docbin([doc]) as output_file:
        reader = Corpus(output_file, augmenter=augmenter)
        corpus = list(reader(nlp))
    eg = corpus[0]
    assert eg.reference.text == doc.text.lower()
    assert eg.predicted.text == doc.text.lower()
    ents = [(e.start, e.end, e.label) for e in doc.ents]
    assert [(e.start, e.end, e.label) for e in eg.reference.ents] == ents
    for ref_ent, orig_ent in zip(eg.reference.ents, doc.ents):
        assert ref_ent.text == orig_ent.text.lower()
    assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc]


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_custom_data_augmentation(nlp, doc):
    def create_spongebob_augmenter(randomize: bool = False):
        def augment(nlp, example):
            text = example.text
            if randomize:
                ch = [c.lower() if random.random() < 0.5 else c.upper() for c in text]
            else:
                ch = [c.lower() if i % 2 else c.upper() for i, c in enumerate(text)]
            example_dict = example.to_dict()
            doc = nlp.make_doc("".join(ch))
            example_dict["token_annotation"]["ORTH"] = [t.text for t in doc]
            yield example
            yield example.from_dict(doc, example_dict)

        return augment

    with make_docbin([doc]) as output_file:
        reader = Corpus(output_file, augmenter=create_spongebob_augmenter())
        corpus = list(reader(nlp))
    orig_text = "Sarah 's sister flew to Silicon Valley via London . "
    augmented = "SaRaH 's sIsTeR FlEw tO SiLiCoN VaLlEy vIa lOnDoN . "
    assert corpus[0].text == orig_text
    assert corpus[0].reference.text == orig_text
    assert corpus[0].predicted.text == orig_text
    assert corpus[1].text == augmented
    assert corpus[1].reference.text == augmented
    assert corpus[1].predicted.text == augmented
    ents = [(e.start, e.end, e.label) for e in doc.ents]
    assert [(e.start, e.end, e.label) for e in corpus[0].reference.ents] == ents
    assert [(e.start, e.end, e.label) for e in corpus[1].reference.ents] == ents

View File

@ -1,23 +1,20 @@
import numpy
from spacy.training import offsets_to_biluo_tags, biluo_tags_to_offsets, Alignment
from spacy.training import biluo_tags_to_spans, iob_to_biluo
from spacy.training import Corpus, docs_to_json
from spacy.training.example import Example
from spacy.training import Corpus, docs_to_json, Example
from spacy.training.converters import json_to_docs
from spacy.training.augment import create_orth_variants_augmenter
from spacy.lang.en import English
from spacy.tokens import Doc, DocBin
from spacy.util import get_words_and_spaces, minibatch
from thinc.api import compounding
import pytest
import srsly
import random
from ..util import make_tempdir


@pytest.fixture
def doc(en_vocab):
def doc():
    nlp = English()  # make sure we get a new vocab every time
    # fmt: off
    words = ["Sarah", "'s", "sister", "flew", "to", "Silicon", "Valley", "via", "London", "."]
@ -495,59 +492,6 @@ def test_roundtrip_docs_to_docbin(doc):
    assert cats["BAKING"] == reloaded_example.reference.cats["BAKING"]


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_make_orth_variants(doc):
    nlp = English()
    orth_variants = {
        "single": [
            {"tags": ["NFP"], "variants": ["…", "..."]},
            {"tags": [":"], "variants": ["-", "—", "–", "--", "---", "——"]},
        ]
    }
    augmenter = create_orth_variants_augmenter(
        level=0.2, lower=0.5, orth_variants=orth_variants
    )
    with make_tempdir() as tmpdir:
        output_file = tmpdir / "roundtrip.spacy"
        DocBin(docs=[doc]).to_disk(output_file)
        # due to randomness, test only that this runs with no errors for now
        reader = Corpus(output_file, augmenter=augmenter)
        list(reader(nlp))


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_custom_data_augmentation(doc):
    def create_spongebob_augmenter(randomize: bool = False):
        def augment(nlp, example):
            text = example.text
            if randomize:
                ch = [c.lower() if random.random() < 0.5 else c.upper() for c in text]
            else:
                ch = [c.lower() if i % 2 else c.upper() for i, c in enumerate(text)]
            example_dict = example.to_dict()
            doc = nlp.make_doc("".join(ch))
            example_dict["token_annotation"]["ORTH"] = [t.text for t in doc]
            yield example
            yield example.from_dict(doc, example_dict)

        return augment

    nlp = English()
    with make_tempdir() as tmpdir:
        output_file = tmpdir / "roundtrip.spacy"
        DocBin(docs=[doc]).to_disk(output_file)
        reader = Corpus(output_file, augmenter=create_spongebob_augmenter())
        corpus = list(reader(nlp))
    orig_text = "Sarah 's sister flew to Silicon Valley via London . "
    augmented = "SaRaH 's sIsTeR FlEw tO SiLiCoN VaLlEy vIa lOnDoN . "
    assert corpus[0].text == orig_text
    assert corpus[0].reference.text == orig_text
    assert corpus[0].predicted.text == orig_text
    assert corpus[1].text == augmented
    assert corpus[1].reference.text == augmented
    assert corpus[1].predicted.text == augmented


@pytest.mark.skip("Outdated")
@pytest.mark.parametrize(
    "tokens_a,tokens_b,expected",

View File

@ -34,16 +34,47 @@ def create_orth_variants_augmenter(
) -> Callable[["Language", Example], Iterator[Example]]:
    """Create a data augmentation callback that uses orth-variant replacement.
    The callback can be added to a corpus or other data iterator during training.

    level (float): The percentage of texts that will be augmented.
    lower (float): The percentage of texts that will be lowercased.
    orth_variants (Dict[str, dict]): A dictionary containing the single and
        paired orth variants. Typically loaded from a JSON file.
    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
    """
    return partial(
        orth_variants_augmenter, orth_variants=orth_variants, level=level, lower=lower
    )


@registry.augmenters("spacy.lower_case.v1")
def create_lower_casing_augmenter(
    level: float,
) -> Callable[["Language", Example], Iterator[Example]]:
    """Create a data augmentation callback that converts documents to lowercase.
    The callback can be added to a corpus or other data iterator during training.

    level (float): The percentage of texts that will be augmented.
    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
    """
    return partial(lower_casing_augmenter, level=level)


def dont_augment(nlp: "Language", example: Example) -> Iterator[Example]:
    yield example


def lower_casing_augmenter(
    nlp: "Language", example: Example, *, level: float,
) -> Iterator[Example]:
    if random.random() >= level:
        yield example
    else:
        example_dict = example.to_dict()
        doc = nlp.make_doc(example.text.lower())
        example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in doc]
        yield example.from_dict(doc, example_dict)


def orth_variants_augmenter(
    nlp: "Language",
    example: Example,
View File

@ -12,6 +12,7 @@ from .iob_utils import biluo_to_iob, offsets_to_biluo_tags, doc_to_biluo_tags
from .iob_utils import biluo_tags_to_spans
from ..errors import Errors, Warnings
from ..pipeline._parser_internals import nonproj
from ..util import logger
cpdef Doc annotations_to_doc(vocab, tok_annot, doc_annot):
@ -390,7 +391,7 @@ def _fix_legacy_dict_data(example_dict):
if "HEAD" in token_dict and "SENT_START" in token_dict:
# If heads are set, we don't also redundantly specify SENT_START.
token_dict.pop("SENT_START")
warnings.warn(Warnings.W092)
logger.debug(Warnings.W092)
return {
"token_annotation": token_dict,
"doc_annotation": doc_dict

View File

@ -689,7 +689,8 @@ Data augmentation is the process of applying small modifications to the training
data. It can be especially useful for punctuation and case replacement for
example, if your corpus only uses smart quotes and you want to include
variations using regular quotes, or to make the model less sensitive to
capitalization by including a mix of capitalized and lowercase examples. See the
[usage guide](/usage/training#data-augmentation) for details and examples.
### spacy.orth_variants.v1 {#orth_variants tag="registered function"}
@ -707,7 +708,7 @@ capitalization by including a mix of capitalized and lowercase examples. See the
> ```
Create a data augmentation callback that uses orth-variant replacement. The
callback can be added to a corpus or other data iterator during training. It's
especially useful for punctuation and case replacement, to help generalize
beyond corpora that don't have smart quotes, or only have smart quotes etc.
@ -718,6 +719,25 @@ beyond corpora that don't have smart quotes, or only have smart quotes etc.
| `orth_variants` | A dictionary containing the single and paired orth variants. Typically loaded from a JSON file. See [`en_orth_variants.json`](https://github.com/explosion/spacy-lookups-data/blob/master/spacy_lookups_data/data/en_orth_variants.json) for an example. ~~Dict[str, Dict[List[Union[str, List[str]]]]]~~ |
| **CREATES** | A function that takes the current `nlp` object and an [`Example`](/api/example) and yields augmented `Example` objects. ~~Callable[[Language, Example], Iterator[Example]]~~ |
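
A minimal sketch of using this augmenter directly in Python, assuming a `DocBin` of training docs saved to `./corpus/train.spacy` (the path and the tiny `orth_variants` table are placeholders for illustration):

```python
from spacy.training import Corpus
from spacy.training.augment import create_orth_variants_augmenter

# A tiny orth-variants table; a real one would typically be loaded from a JSON
# file such as en_orth_variants.json from spacy-lookups-data.
orth_variants = {
    "single": [
        {"tags": ["NFP"], "variants": ["…", "..."]},
        {"tags": [":"], "variants": ["-", "—", "–", "--", "---", "——"]},
    ]
}
augmenter = create_orth_variants_augmenter(
    level=0.2, lower=0.5, orth_variants=orth_variants
)
# The augmenter is applied on the fly each time the corpus is iterated
corpus = Corpus("./corpus/train.spacy", augmenter=augmenter)
```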
### spacy.lower_case.v1 {#lower_case tag="registered function"}
> #### Example config
>
> ```ini
> [corpora.train.augmenter]
> @augmenters = "spacy.lower_case.v1"
> level = 0.3
> ```
Create a data augmentation callback that lowercases documents. The callback can
be added to a corpus or other data iterator during training. It's especially
useful for making the model less sensitive to capitalization.
| Name | Description |
| ----------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `level` | The percentage of texts that will be augmented. ~~float~~ |
| **CREATES** | A function that takes the current `nlp` object and an [`Example`](/api/example) and yields augmented `Example` objects. ~~Callable[[Language, Example], Iterator[Example]]~~ |
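
A minimal sketch of applying the lowercase augmenter to a single example in Python (the sentence and the `level` value are placeholders; with `level=1.0` every example is lowercased):

```python
from spacy.lang.en import English
from spacy.training import Example
from spacy.training.augment import create_lower_casing_augmenter

nlp = English()
augmenter = create_lower_casing_augmenter(level=1.0)
doc = nlp.make_doc("Sarah flew to London.")
example = Example.from_dict(doc, {"words": [t.text for t in doc]})
# The augmenter is a callable that takes the nlp object and an Example and
# yields (possibly modified) Example objects.
for augmented in augmenter(nlp, example):
    print(augmented.reference.text)  # "sarah flew to london."
```

In a training config, the same augmenter would instead be attached to the
corpus reader via the `[corpora.train.augmenter]` block shown above.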
## Training data and alignment {#gold source="spacy/training"}
### training.offsets_to_biluo_tags {#offsets_to_biluo_tags tag="function"}
@ -827,10 +847,10 @@ utilities.
### util.get_lang_class {#util.get_lang_class tag="function"}
Import and load a `Language` class. Allows lazy-loading
[language data](/usage/linguistic-features#language-data) and importing
languages using the two-letter language code. To add a language code for a
custom language class, you can register it using the
[`@registry.languages`](/api/top-level#registry) decorator.
> #### Example
>