Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-12 18:26:30 +03:00)
Update data augmenters (#6196)

* Draft lower-case augmenter
* Make warning a debug log
* Update lowercase augmenter, docs and tests

Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
parent d38dc466c5 · commit 3c36a57e84
New file: spacy/tests/training/test_augmenters.py (+100 lines)
```python
import pytest
from spacy.training import Corpus
from spacy.training.augment import create_orth_variants_augmenter
from spacy.training.augment import create_lower_casing_augmenter
from spacy.lang.en import English
from spacy.tokens import DocBin, Doc
from contextlib import contextmanager
import random

from ..util import make_tempdir


@contextmanager
def make_docbin(docs, name="roundtrip.spacy"):
    with make_tempdir() as tmpdir:
        output_file = tmpdir / name
        DocBin(docs=docs).to_disk(output_file)
        yield output_file


@pytest.fixture
def nlp():
    return English()


@pytest.fixture
def doc(nlp):
    # fmt: off
    words = ["Sarah", "'s", "sister", "flew", "to", "Silicon", "Valley", "via", "London", "."]
    tags = ["NNP", "POS", "NN", "VBD", "IN", "NNP", "NNP", "IN", "NNP", "."]
    pos = ["PROPN", "PART", "NOUN", "VERB", "ADP", "PROPN", "PROPN", "ADP", "PROPN", "PUNCT"]
    ents = ["B-PERSON", "I-PERSON", "O", "O", "O", "B-LOC", "I-LOC", "O", "B-GPE", "O"]
    cats = {"TRAVEL": 1.0, "BAKING": 0.0}
    # fmt: on
    doc = Doc(nlp.vocab, words=words, tags=tags, pos=pos, ents=ents)
    doc.cats = cats
    return doc


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_make_orth_variants(nlp, doc):
    single = [
        {"tags": ["NFP"], "variants": ["…", "..."]},
        {"tags": [":"], "variants": ["-", "—", "–", "--", "---", "——"]},
    ]
    augmenter = create_orth_variants_augmenter(
        level=0.2, lower=0.5, orth_variants={"single": single}
    )
    with make_docbin([doc]) as output_file:
        reader = Corpus(output_file, augmenter=augmenter)
        # Due to randomness, only test that it works without errors for now
        list(reader(nlp))


def test_lowercase_augmenter(nlp, doc):
    augmenter = create_lower_casing_augmenter(level=1.0)
    with make_docbin([doc]) as output_file:
        reader = Corpus(output_file, augmenter=augmenter)
        corpus = list(reader(nlp))
    eg = corpus[0]
    assert eg.reference.text == doc.text.lower()
    assert eg.predicted.text == doc.text.lower()
    ents = [(e.start, e.end, e.label) for e in doc.ents]
    assert [(e.start, e.end, e.label) for e in eg.reference.ents] == ents
    for ref_ent, orig_ent in zip(eg.reference.ents, doc.ents):
        assert ref_ent.text == orig_ent.text.lower()
    assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc]


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_custom_data_augmentation(nlp, doc):
    def create_spongebob_augmenter(randomize: bool = False):
        def augment(nlp, example):
            text = example.text
            if randomize:
                ch = [c.lower() if random.random() < 0.5 else c.upper() for c in text]
            else:
                ch = [c.lower() if i % 2 else c.upper() for i, c in enumerate(text)]
            example_dict = example.to_dict()
            doc = nlp.make_doc("".join(ch))
            example_dict["token_annotation"]["ORTH"] = [t.text for t in doc]
            yield example
            yield example.from_dict(doc, example_dict)

        return augment

    with make_docbin([doc]) as output_file:
        reader = Corpus(output_file, augmenter=create_spongebob_augmenter())
        corpus = list(reader(nlp))
    orig_text = "Sarah 's sister flew to Silicon Valley via London . "
    augmented = "SaRaH 's sIsTeR FlEw tO SiLiCoN VaLlEy vIa lOnDoN . "
    assert corpus[0].text == orig_text
    assert corpus[0].reference.text == orig_text
    assert corpus[0].predicted.text == orig_text
    assert corpus[1].text == augmented
    assert corpus[1].reference.text == augmented
    assert corpus[1].predicted.text == augmented
    ents = [(e.start, e.end, e.label) for e in doc.ents]
    assert [(e.start, e.end, e.label) for e in corpus[0].reference.ents] == ents
    assert [(e.start, e.end, e.label) for e in corpus[1].reference.ents] == ents
```
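The spongebob augmenter in the test above is defined inline, but the same pattern can be registered so a config file can refer to it by name, mirroring how `spacy.lower_case.v1` is registered in `spacy/training/augment.py` below. This is a rough sketch, not part of the commit; the registry name `my_spongebob.v1` is made up for illustration:

```python
import random

from spacy.util import registry


@registry.augmenters("my_spongebob.v1")  # hypothetical name, for illustration only
def create_spongebob_augmenter(randomize: bool = False):
    def augment(nlp, example):
        text = example.text
        if randomize:
            ch = [c.lower() if random.random() < 0.5 else c.upper() for c in text]
        else:
            ch = [c.lower() if i % 2 else c.upper() for i, c in enumerate(text)]
        example_dict = example.to_dict()
        doc = nlp.make_doc("".join(ch))
        example_dict["token_annotation"]["ORTH"] = [t.text for t in doc]
        yield example  # keep the original example
        yield example.from_dict(doc, example_dict)  # plus the augmented copy

    return augment
```

Once registered, it could be selected in the training config with `@augmenters = "my_spongebob.v1"` under `[corpora.train.augmenter]`, just like the built-in augmenters documented further down.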
Changed: spacy/tests/training/test_training.py

```diff
@@ -1,23 +1,20 @@
 import numpy
 from spacy.training import offsets_to_biluo_tags, biluo_tags_to_offsets, Alignment
 from spacy.training import biluo_tags_to_spans, iob_to_biluo
-from spacy.training import Corpus, docs_to_json
-from spacy.training.example import Example
+from spacy.training import Corpus, docs_to_json, Example
 from spacy.training.converters import json_to_docs
-from spacy.training.augment import create_orth_variants_augmenter
 from spacy.lang.en import English
 from spacy.tokens import Doc, DocBin
 from spacy.util import get_words_and_spaces, minibatch
 from thinc.api import compounding
 import pytest
 import srsly
-import random

 from ..util import make_tempdir


 @pytest.fixture
-def doc(en_vocab):
+def doc():
+    nlp = English()  # make sure we get a new vocab every time
     # fmt: off
     words = ["Sarah", "'s", "sister", "flew", "to", "Silicon", "Valley", "via", "London", "."]
```
```diff
@@ -495,59 +492,6 @@ def test_roundtrip_docs_to_docbin(doc):
     assert cats["BAKING"] == reloaded_example.reference.cats["BAKING"]


-@pytest.mark.filterwarnings("ignore::UserWarning")
-def test_make_orth_variants(doc):
-    nlp = English()
-    orth_variants = {
-        "single": [
-            {"tags": ["NFP"], "variants": ["…", "..."]},
-            {"tags": [":"], "variants": ["-", "—", "–", "--", "---", "——"]},
-        ]
-    }
-    augmenter = create_orth_variants_augmenter(
-        level=0.2, lower=0.5, orth_variants=orth_variants
-    )
-    with make_tempdir() as tmpdir:
-        output_file = tmpdir / "roundtrip.spacy"
-        DocBin(docs=[doc]).to_disk(output_file)
-        # due to randomness, test only that this runs with no errors for now
-        reader = Corpus(output_file, augmenter=augmenter)
-        list(reader(nlp))
-
-
-@pytest.mark.filterwarnings("ignore::UserWarning")
-def test_custom_data_augmentation(doc):
-    def create_spongebob_augmenter(randomize: bool = False):
-        def augment(nlp, example):
-            text = example.text
-            if randomize:
-                ch = [c.lower() if random.random() < 0.5 else c.upper() for c in text]
-            else:
-                ch = [c.lower() if i % 2 else c.upper() for i, c in enumerate(text)]
-            example_dict = example.to_dict()
-            doc = nlp.make_doc("".join(ch))
-            example_dict["token_annotation"]["ORTH"] = [t.text for t in doc]
-            yield example
-            yield example.from_dict(doc, example_dict)
-
-        return augment
-
-    nlp = English()
-    with make_tempdir() as tmpdir:
-        output_file = tmpdir / "roundtrip.spacy"
-        DocBin(docs=[doc]).to_disk(output_file)
-        reader = Corpus(output_file, augmenter=create_spongebob_augmenter())
-        corpus = list(reader(nlp))
-    orig_text = "Sarah 's sister flew to Silicon Valley via London . "
-    augmented = "SaRaH 's sIsTeR FlEw tO SiLiCoN VaLlEy vIa lOnDoN . "
-    assert corpus[0].text == orig_text
-    assert corpus[0].reference.text == orig_text
-    assert corpus[0].predicted.text == orig_text
-    assert corpus[1].text == augmented
-    assert corpus[1].reference.text == augmented
-    assert corpus[1].predicted.text == augmented
-
-
 @pytest.mark.skip("Outdated")
 @pytest.mark.parametrize(
     "tokens_a,tokens_b,expected",
```
Changed: spacy/training/augment.py

```diff
@@ -34,16 +34,47 @@ def create_orth_variants_augmenter(
 ) -> Callable[["Language", Example], Iterator[Example]]:
     """Create a data augmentation callback that uses orth-variant replacement.
     The callback can be added to a corpus or other data iterator during training.

     level (float): The percentage of texts that will be augmented.
     lower (float): The percentage of texts that will be lowercased.
     orth_variants (Dict[str, dict]): A dictionary containing the single and
         paired orth variants. Typically loaded from a JSON file.
     RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
     """
     return partial(
         orth_variants_augmenter, orth_variants=orth_variants, level=level, lower=lower
     )


+@registry.augmenters("spacy.lower_case.v1")
+def create_lower_casing_augmenter(
+    level: float,
+) -> Callable[["Language", Example], Iterator[Example]]:
+    """Create a data augmentation callback that converts documents to lowercase.
+    The callback can be added to a corpus or other data iterator during training.
+
+    level (float): The percentage of texts that will be augmented.
+    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
+    """
+    return partial(lower_casing_augmenter, level=level)
+
+
+def dont_augment(nlp: "Language", example: Example) -> Iterator[Example]:
+    yield example
+
+
+def lower_casing_augmenter(
+    nlp: "Language", example: Example, *, level: float,
+) -> Iterator[Example]:
+    if random.random() >= level:
+        yield example
+    else:
+        example_dict = example.to_dict()
+        doc = nlp.make_doc(example.text.lower())
+        example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in doc]
+        yield example.from_dict(doc, example_dict)
+
+
 def orth_variants_augmenter(
     nlp: "Language",
     example: Example,
```
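As a quick illustration of what the new augmenter does (a sketch for this writeup, not part of the diff), calling it directly on a single `Example` with `level=1.0` always takes the lowercasing branch and yields a lowercased copy:

```python
from spacy.lang.en import English
from spacy.training import Example
from spacy.training.augment import create_lower_casing_augmenter

nlp = English()
augmenter = create_lower_casing_augmenter(level=1.0)
doc = nlp.make_doc("Sarah flew to London.")
example = Example.from_dict(doc, {"words": [t.text for t in doc]})
# The augmenter is a generator over (possibly) augmented examples
augmented = list(augmenter(nlp, example))
print([eg.reference.text for eg in augmented])  # expected: ['sarah flew to london.']
```

With `level` below 1.0, each example is lowercased with that probability and passed through unchanged otherwise, which is what the `random.random() >= level` check implements.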
Changed: spacy/training/example.pyx

```diff
@@ -12,6 +12,7 @@ from .iob_utils import biluo_to_iob, offsets_to_biluo_tags, doc_to_biluo_tags
 from .iob_utils import biluo_tags_to_spans
 from ..errors import Errors, Warnings
 from ..pipeline._parser_internals import nonproj
+from ..util import logger


 cpdef Doc annotations_to_doc(vocab, tok_annot, doc_annot):
```

```diff
@@ -390,7 +391,7 @@ def _fix_legacy_dict_data(example_dict):
     if "HEAD" in token_dict and "SENT_START" in token_dict:
         # If heads are set, we don't also redundantly specify SENT_START.
         token_dict.pop("SENT_START")
-        warnings.warn(Warnings.W092)
+        logger.debug(Warnings.W092)
     return {
         "token_annotation": token_dict,
         "doc_annotation": doc_dict
```
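Because W092 now goes through spaCy's logger instead of `warnings.warn`, it is only visible when debug logging is enabled. A minimal sketch using the standard `logging` module (assuming the logger name `"spacy"`, which is what `spacy.util.logger` uses):

```python
import logging

# Enable DEBUG output for spaCy's logger so messages like W092,
# which used to be emitted as warnings, show up again.
logging.basicConfig()
logging.getLogger("spacy").setLevel(logging.DEBUG)
```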
Changed: website/docs/api/top-level.md

```diff
@@ -689,7 +689,8 @@ Data augmentation is the process of applying small modifications to the training
 data. It can be especially useful for punctuation and case replacement – for
 example, if your corpus only uses smart quotes and you want to include
 variations using regular quotes, or to make the model less sensitive to
-capitalization by including a mix of capitalized and lowercase examples. See the [usage guide](/usage/training#data-augmentation) for details and examples.
+capitalization by including a mix of capitalized and lowercase examples. See the
+[usage guide](/usage/training#data-augmentation) for details and examples.

 ### spacy.orth_variants.v1 {#orth_variants tag="registered function"}
```
```diff
@@ -707,7 +708,7 @@ capitalization by including a mix of capitalized and lowercase examples. See the
 > ```

 Create a data augmentation callback that uses orth-variant replacement. The
-callback can be added to a corpus or other data iterator during training. This
-is especially useful for punctuation and case replacement, to help generalize
+callback can be added to a corpus or other data iterator during training. It's
+especially useful for punctuation and case replacement, to help generalize
 beyond corpora that don't have smart quotes, or only have smart quotes etc.
```
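For reference, the `orth_variants` dictionary the augmenter expects looks like the structure used in the new tests; this sketch only shows the `"single"` key, while the full `en_orth_variants.json` in `spacy-lookups-data` also provides `"paired"` variants:

```python
orth_variants = {
    "single": [
        # Each entry maps tags to interchangeable surface forms
        {"tags": ["NFP"], "variants": ["…", "..."]},
        {"tags": [":"], "variants": ["-", "—", "–", "--", "---", "——"]},
    ],
    # "paired" entries (e.g. opening/closing quote styles) are omitted here;
    # see en_orth_variants.json in spacy-lookups-data for the full format.
}
```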
```diff
@@ -718,6 +719,25 @@ beyond corpora that don't have smart quotes, or only have smart quotes etc.
 | `orth_variants` | A dictionary containing the single and paired orth variants. Typically loaded from a JSON file. See [`en_orth_variants.json`](https://github.com/explosion/spacy-lookups-data/blob/master/spacy_lookups_data/data/en_orth_variants.json) for an example. ~~Dict[str, Dict[List[Union[str, List[str]]]]]~~ |
 | **CREATES** | A function that takes the current `nlp` object and an [`Example`](/api/example) and yields augmented `Example` objects. ~~Callable[[Language, Example], Iterator[Example]]~~ |

+### spacy.lower_case.v1 {#lower_case tag="registered function"}
+
+> #### Example config
+>
+> ```ini
+> [corpora.train.augmenter]
+> @augmenters = "spacy.lower_case.v1"
+> level = 0.3
+> ```
+
+Create a data augmentation callback that lowercases documents. The callback can
+be added to a corpus or other data iterator during training. It's especially
+useful for making the model less sensitive to capitalization.
+
+| Name        | Description |
+| ----------- | ----------- |
+| `level`     | The percentage of texts that will be augmented. ~~float~~ |
+| **CREATES** | A function that takes the current `nlp` object and an [`Example`](/api/example) and yields augmented `Example` objects. ~~Callable[[Language, Example], Iterator[Example]]~~ |
+
 ## Training data and alignment {#gold source="spacy/training"}

 ### training.offsets_to_biluo_tags {#offsets_to_biluo_tags tag="function"}
```
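Outside the config, the same registered augmenter can be passed to a `Corpus` directly, which is exactly what the new tests do. A sketch assuming a `DocBin` has been serialized to `./train.spacy` (a stand-in path for illustration):

```python
from spacy.lang.en import English
from spacy.training import Corpus
from spacy.training.augment import create_lower_casing_augmenter

nlp = English()
augmenter = create_lower_casing_augmenter(level=0.3)
# "./train.spacy" is a placeholder path to a serialized DocBin of training docs
reader = Corpus("./train.spacy", augmenter=augmenter)
examples = list(reader(nlp))
```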
```diff
@@ -827,10 +847,10 @@ utilities.
 ### util.get_lang_class {#util.get_lang_class tag="function"}

 Import and load a `Language` class. Allows lazy-loading
-[language data](/usage/linguistic-features#language-data) and importing languages using the
-two-letter language code. To add a language code for a custom language class,
-you can register it using the [`@registry.languages`](/api/top-level#registry)
-decorator.
+[language data](/usage/linguistic-features#language-data) and importing
+languages using the two-letter language code. To add a language code for a
+custom language class, you can register it using the
+[`@registry.languages`](/api/top-level#registry) decorator.

 > #### Example
 >
```
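The `> #### Example` block is cut off in this view; a typical usage sketch of `util.get_lang_class` (reconstructed here, not verbatim from the docs):

```python
from spacy import util

# Look up the Language subclass for a two-letter code and instantiate it
lang_class = util.get_lang_class("en")
nlp = lang_class()
```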