Mirror of https://github.com/explosion/spaCy.git
Update data augmenters (#6196)
* Draft lower-case augmenter
* Make warning a debug log
* Update lowercase augmenter, docs and tests

Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
This commit is contained in:
parent d38dc466c5
commit 3c36a57e84

spacy/tests/training/test_augmenters.py | 100 | Normal file
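For context, the commit registers the new augmenter as `spacy.lower_case.v1`, so it can be enabled on a corpus reader from the training config. A minimal sketch, mirroring the example config added to the API docs further down in this diff; the `[corpora.train.augmenter]` location and the `level` value are illustrative:

```ini
[corpora.train.augmenter]
@augmenters = "spacy.lower_case.v1"
level = 0.3
```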
spacy/tests/training/test_augmenters.py (new file)
@@ -0,0 +1,100 @@
import pytest
from spacy.training import Corpus
from spacy.training.augment import create_orth_variants_augmenter
from spacy.training.augment import create_lower_casing_augmenter
from spacy.lang.en import English
from spacy.tokens import DocBin, Doc
from contextlib import contextmanager
import random

from ..util import make_tempdir


@contextmanager
def make_docbin(docs, name="roundtrip.spacy"):
    with make_tempdir() as tmpdir:
        output_file = tmpdir / name
        DocBin(docs=docs).to_disk(output_file)
        yield output_file


@pytest.fixture
def nlp():
    return English()


@pytest.fixture
def doc(nlp):
    # fmt: off
    words = ["Sarah", "'s", "sister", "flew", "to", "Silicon", "Valley", "via", "London", "."]
    tags = ["NNP", "POS", "NN", "VBD", "IN", "NNP", "NNP", "IN", "NNP", "."]
    pos = ["PROPN", "PART", "NOUN", "VERB", "ADP", "PROPN", "PROPN", "ADP", "PROPN", "PUNCT"]
    ents = ["B-PERSON", "I-PERSON", "O", "O", "O", "B-LOC", "I-LOC", "O", "B-GPE", "O"]
    cats = {"TRAVEL": 1.0, "BAKING": 0.0}
    # fmt: on
    doc = Doc(nlp.vocab, words=words, tags=tags, pos=pos, ents=ents)
    doc.cats = cats
    return doc


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_make_orth_variants(nlp, doc):
    single = [
        {"tags": ["NFP"], "variants": ["…", "..."]},
        {"tags": [":"], "variants": ["-", "—", "–", "--", "---", "——"]},
    ]
    augmenter = create_orth_variants_augmenter(
        level=0.2, lower=0.5, orth_variants={"single": single}
    )
    with make_docbin([doc]) as output_file:
        reader = Corpus(output_file, augmenter=augmenter)
        # Due to randomness, only test that it works without errors for now
        list(reader(nlp))


def test_lowercase_augmenter(nlp, doc):
    augmenter = create_lower_casing_augmenter(level=1.0)
    with make_docbin([doc]) as output_file:
        reader = Corpus(output_file, augmenter=augmenter)
        corpus = list(reader(nlp))
    eg = corpus[0]
    assert eg.reference.text == doc.text.lower()
    assert eg.predicted.text == doc.text.lower()
    ents = [(e.start, e.end, e.label) for e in doc.ents]
    assert [(e.start, e.end, e.label) for e in eg.reference.ents] == ents
    for ref_ent, orig_ent in zip(eg.reference.ents, doc.ents):
        assert ref_ent.text == orig_ent.text.lower()
    assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc]


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_custom_data_augmentation(nlp, doc):
    def create_spongebob_augmenter(randomize: bool = False):
        def augment(nlp, example):
            text = example.text
            if randomize:
                ch = [c.lower() if random.random() < 0.5 else c.upper() for c in text]
            else:
                ch = [c.lower() if i % 2 else c.upper() for i, c in enumerate(text)]
            example_dict = example.to_dict()
            doc = nlp.make_doc("".join(ch))
            example_dict["token_annotation"]["ORTH"] = [t.text for t in doc]
            yield example
            yield example.from_dict(doc, example_dict)

        return augment

    with make_docbin([doc]) as output_file:
        reader = Corpus(output_file, augmenter=create_spongebob_augmenter())
        corpus = list(reader(nlp))
    orig_text = "Sarah 's sister flew to Silicon Valley via London . "
    augmented = "SaRaH 's sIsTeR FlEw tO SiLiCoN VaLlEy vIa lOnDoN . "
    assert corpus[0].text == orig_text
    assert corpus[0].reference.text == orig_text
    assert corpus[0].predicted.text == orig_text
    assert corpus[1].text == augmented
    assert corpus[1].reference.text == augmented
    assert corpus[1].predicted.text == augmented
    ents = [(e.start, e.end, e.label) for e in doc.ents]
    assert [(e.start, e.end, e.label) for e in corpus[0].reference.ents] == ents
    assert [(e.start, e.end, e.label) for e in corpus[1].reference.ents] == ents
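Outside the test suite, a custom augmenter like the spongebob one above would typically be registered so it can be referenced by name from a training config. A hypothetical sketch, assuming spaCy v3's augmenters registry; the name "my_spongebob.v1" is made up and not part of this commit:

```python
import random

import spacy


@spacy.registry.augmenters("my_spongebob.v1")  # hypothetical registry name
def create_spongebob_augmenter(randomize: bool = False):
    def augment(nlp, example):
        text = example.text
        if randomize:
            # randomly choose the case of each character
            ch = [c.lower() if random.random() < 0.5 else c.upper() for c in text]
        else:
            # alternate lower/upper case deterministically
            ch = [c.lower() if i % 2 else c.upper() for i, c in enumerate(text)]
        example_dict = example.to_dict()
        doc = nlp.make_doc("".join(ch))
        example_dict["token_annotation"]["ORTH"] = [t.text for t in doc]
        # yield the original example as well as the augmented copy
        yield example
        yield example.from_dict(doc, example_dict)

    return augment
```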
@@ -1,23 +1,20 @@
 import numpy
 from spacy.training import offsets_to_biluo_tags, biluo_tags_to_offsets, Alignment
 from spacy.training import biluo_tags_to_spans, iob_to_biluo
-from spacy.training import Corpus, docs_to_json
+from spacy.training import Corpus, docs_to_json, Example
-from spacy.training.example import Example
 from spacy.training.converters import json_to_docs
-from spacy.training.augment import create_orth_variants_augmenter
 from spacy.lang.en import English
 from spacy.tokens import Doc, DocBin
 from spacy.util import get_words_and_spaces, minibatch
 from thinc.api import compounding
 import pytest
 import srsly
-import random

 from ..util import make_tempdir


 @pytest.fixture
-def doc(en_vocab):
+def doc():
     nlp = English()  # make sure we get a new vocab every time
     # fmt: off
     words = ["Sarah", "'s", "sister", "flew", "to", "Silicon", "Valley", "via", "London", "."]

@@ -495,59 +492,6 @@ def test_roundtrip_docs_to_docbin(doc):
     assert cats["BAKING"] == reloaded_example.reference.cats["BAKING"]


-@pytest.mark.filterwarnings("ignore::UserWarning")
-def test_make_orth_variants(doc):
-    nlp = English()
-    orth_variants = {
-        "single": [
-            {"tags": ["NFP"], "variants": ["…", "..."]},
-            {"tags": [":"], "variants": ["-", "—", "–", "--", "---", "——"]},
-        ]
-    }
-    augmenter = create_orth_variants_augmenter(
-        level=0.2, lower=0.5, orth_variants=orth_variants
-    )
-    with make_tempdir() as tmpdir:
-        output_file = tmpdir / "roundtrip.spacy"
-        DocBin(docs=[doc]).to_disk(output_file)
-        # due to randomness, test only that this runs with no errors for now
-        reader = Corpus(output_file, augmenter=augmenter)
-        list(reader(nlp))
-
-
-@pytest.mark.filterwarnings("ignore::UserWarning")
-def test_custom_data_augmentation(doc):
-    def create_spongebob_augmenter(randomize: bool = False):
-        def augment(nlp, example):
-            text = example.text
-            if randomize:
-                ch = [c.lower() if random.random() < 0.5 else c.upper() for c in text]
-            else:
-                ch = [c.lower() if i % 2 else c.upper() for i, c in enumerate(text)]
-            example_dict = example.to_dict()
-            doc = nlp.make_doc("".join(ch))
-            example_dict["token_annotation"]["ORTH"] = [t.text for t in doc]
-            yield example
-            yield example.from_dict(doc, example_dict)
-
-        return augment
-
-    nlp = English()
-    with make_tempdir() as tmpdir:
-        output_file = tmpdir / "roundtrip.spacy"
-        DocBin(docs=[doc]).to_disk(output_file)
-        reader = Corpus(output_file, augmenter=create_spongebob_augmenter())
-        corpus = list(reader(nlp))
-    orig_text = "Sarah 's sister flew to Silicon Valley via London . "
-    augmented = "SaRaH 's sIsTeR FlEw tO SiLiCoN VaLlEy vIa lOnDoN . "
-    assert corpus[0].text == orig_text
-    assert corpus[0].reference.text == orig_text
-    assert corpus[0].predicted.text == orig_text
-    assert corpus[1].text == augmented
-    assert corpus[1].reference.text == augmented
-    assert corpus[1].predicted.text == augmented
-
-
 @pytest.mark.skip("Outdated")
 @pytest.mark.parametrize(
     "tokens_a,tokens_b,expected",
@@ -34,16 +34,47 @@ def create_orth_variants_augmenter(
 ) -> Callable[["Language", Example], Iterator[Example]]:
     """Create a data augmentation callback that uses orth-variant replacement.
     The callback can be added to a corpus or other data iterator during training.
+
+    level (float): The percentage of texts that will be augmented.
+    lower (float): The percentage of texts that will be lowercased.
+    orth_variants (Dict[str, dict]): A dictionary containing the single and
+        paired orth variants. Typically loaded from a JSON file.
+    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
     """
     return partial(
         orth_variants_augmenter, orth_variants=orth_variants, level=level, lower=lower
     )


+@registry.augmenters("spacy.lower_case.v1")
+def create_lower_casing_augmenter(
+    level: float,
+) -> Callable[["Language", Example], Iterator[Example]]:
+    """Create a data augmentation callback that converts documents to lowercase.
+    The callback can be added to a corpus or other data iterator during training.
+
+    level (float): The percentage of texts that will be augmented.
+    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
+    """
+    return partial(lower_casing_augmenter, level=level)
+
+
 def dont_augment(nlp: "Language", example: Example) -> Iterator[Example]:
     yield example


+def lower_casing_augmenter(
+    nlp: "Language", example: Example, *, level: float,
+) -> Iterator[Example]:
+    if random.random() >= level:
+        yield example
+    else:
+        example_dict = example.to_dict()
+        doc = nlp.make_doc(example.text.lower())
+        example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in doc]
+        yield example.from_dict(doc, example_dict)
+
+
 def orth_variants_augmenter(
     nlp: "Language",
     example: Example,
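To make the callback signature above concrete, here is a small sketch (assuming spaCy v3, as targeted by this commit) that calls the new lower-casing augmenter directly on a single Example instead of going through a Corpus:

```python
from spacy.lang.en import English
from spacy.training import Example
from spacy.training.augment import create_lower_casing_augmenter

nlp = English()
reference = nlp.make_doc("Sarah flew to London.")
example = Example(nlp.make_doc(reference.text), reference)

# level=1.0 means every text is augmented, as in the new test
augmenter = create_lower_casing_augmenter(level=1.0)
for augmented in augmenter(nlp, example):
    print(augmented.reference.text)  # the lowercased copy
```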
@@ -12,6 +12,7 @@ from .iob_utils import biluo_to_iob, offsets_to_biluo_tags, doc_to_biluo_tags
 from .iob_utils import biluo_tags_to_spans
 from ..errors import Errors, Warnings
 from ..pipeline._parser_internals import nonproj
+from ..util import logger


 cpdef Doc annotations_to_doc(vocab, tok_annot, doc_annot):

@@ -390,7 +391,7 @@ def _fix_legacy_dict_data(example_dict):
     if "HEAD" in token_dict and "SENT_START" in token_dict:
         # If heads are set, we don't also redundantly specify SENT_START.
         token_dict.pop("SENT_START")
-        warnings.warn(Warnings.W092)
+        logger.debug(Warnings.W092)
     return {
         "token_annotation": token_dict,
         "doc_annotation": doc_dict
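Since W092 is now routed through logger.debug rather than warnings.warn, it only appears when spaCy's logger is set to DEBUG. A minimal sketch for surfacing it, assuming spaCy uses the standard "spacy" logger name:

```python
import logging

# show spaCy debug messages such as W092 during data conversion
logging.basicConfig()
logging.getLogger("spacy").setLevel(logging.DEBUG)
```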
@@ -689,7 +689,8 @@ Data augmentation is the process of applying small modifications to the training
 data. It can be especially useful for punctuation and case replacement – for
 example, if your corpus only uses smart quotes and you want to include
 variations using regular quotes, or to make the model less sensitive to
-capitalization by including a mix of capitalized and lowercase examples. See the [usage guide](/usage/training#data-augmentation) for details and examples.
+capitalization by including a mix of capitalized and lowercase examples. See the
+[usage guide](/usage/training#data-augmentation) for details and examples.

 ### spacy.orth_variants.v1 {#orth_variants tag="registered function"}

@@ -707,7 +708,7 @@ capitalization by including a mix of capitalized and lowercase examples. See the
 > ```

 Create a data augmentation callback that uses orth-variant replacement. The
-callback can be added to a corpus or other data iterator during training. This
+callback can be added to a corpus or other data iterator during training. It's
 is especially useful for punctuation and case replacement, to help generalize
 beyond corpora that don't have smart quotes, or only have smart quotes etc.

@@ -718,6 +719,25 @@ beyond corpora that don't have smart quotes, or only have smart quotes etc.
 | `orth_variants` | A dictionary containing the single and paired orth variants. Typically loaded from a JSON file. See [`en_orth_variants.json`](https://github.com/explosion/spacy-lookups-data/blob/master/spacy_lookups_data/data/en_orth_variants.json) for an example. ~~Dict[str, Dict[List[Union[str, List[str]]]]]~~ |
 | **CREATES** | A function that takes the current `nlp` object and an [`Example`](/api/example) and yields augmented `Example` objects. ~~Callable[[Language, Example], Iterator[Example]]~~ |

+### spacy.lower_case.v1 {#lower_case tag="registered function"}
+
+> #### Example config
+>
+> ```ini
+> [corpora.train.augmenter]
+> @augmenters = "spacy.lower_case.v1"
+> level = 0.3
+> ```
+
+Create a data augmentation callback that lowercases documents. The callback can
+be added to a corpus or other data iterator during training. It's especially
+useful for making the model less sensitive to capitalization.
+
+| Name        | Description |
+| ----------- | ----------- |
+| `level`     | The percentage of texts that will be augmented. ~~float~~ |
+| **CREATES** | A function that takes the current `nlp` object and an [`Example`](/api/example) and yields augmented `Example` objects. ~~Callable[[Language, Example], Iterator[Example]]~~ |
+
 ## Training data and alignment {#gold source="spacy/training"}

 ### training.offsets_to_biluo_tags {#offsets_to_biluo_tags tag="function"}

@@ -827,10 +847,10 @@ utilities.
 ### util.get_lang_class {#util.get_lang_class tag="function"}

 Import and load a `Language` class. Allows lazy-loading
-[language data](/usage/linguistic-features#language-data) and importing languages using the
-two-letter language code. To add a language code for a custom language class,
-you can register it using the [`@registry.languages`](/api/top-level#registry)
-decorator.
+[language data](/usage/linguistic-features#language-data) and importing
+languages using the two-letter language code. To add a language code for a
+custom language class, you can register it using the
+[`@registry.languages`](/api/top-level#registry) decorator.

 > #### Example
 >