Support data augmentation in Corpus (#6155)

* Support data augmentation in Corpus

* Note initial docs for data augmentation

* Add augmenter to quickstart

* Fix flake8

* Format

* Fix test

* Update spacy/tests/training/test_training.py

* Improve data augmentation arguments

* Update templates

* Move randomization out into caller

* Refactor

* Update spacy/training/augment.py

* Update spacy/tests/training/test_training.py

* Fix augment

* Fix test
Matthew Honnibal 2020-09-28 03:03:27 +02:00 committed by GitHub
parent cad4dbddaa
commit a976da168c
9 changed files with 86 additions and 29 deletions

spacy/cli/templates/quickstart_training.jinja

@@ -270,6 +270,7 @@ factory = "{{ pipe }}"
 @readers = "spacy.Corpus.v1"
 path = ${paths.train}
 max_length = {{ 500 if hardware == "gpu" else 2000 }}
+augmenter = {"@augmenters": "spacy.orth_variants.v1", "level": 0.1, "lower": 0.5}
 
 [corpora.dev]
 @readers = "spacy.Corpus.v1"

spacy/default_config.cfg

@@ -35,6 +35,11 @@ gold_preproc = false
 max_length = 0
 # Limitation on number of training examples
 limit = 0
+# Apply some simple data augmentation, where we replace tokens with variations.
+# This is especially useful for punctuation and case replacement, to help
+# generalize beyond corpora that don't have smart quotes, or only have smart
+# quotes, etc.
+augmenter = {"@augmenters": "spacy.orth_variants.v1", "level": 0.1, "lower": 0.5}
 
 [corpora.dev]
 @readers = "spacy.Corpus.v1"

spacy/tests/training/test_training.py

@@ -4,7 +4,7 @@ from spacy.training import biluo_tags_to_spans, iob_to_biluo
 from spacy.training import Corpus, docs_to_json
 from spacy.training.example import Example
 from spacy.training.converters import json_to_docs
-from spacy.training.augment import make_orth_variants_example
+from spacy.training.augment import create_orth_variants_augmenter
 from spacy.lang.en import English
 from spacy.tokens import Doc, DocBin
 from spacy.util import get_words_and_spaces, minibatch
@@ -496,9 +496,8 @@ def test_make_orth_variants(doc):
     output_file = tmpdir / "roundtrip.spacy"
     DocBin(docs=[doc]).to_disk(output_file)
     # due to randomness, test only that this runs with no errors for now
-    reader = Corpus(output_file)
-    train_example = next(reader(nlp))
-    make_orth_variants_example(nlp, train_example, orth_variant_level=0.2)
+    reader = Corpus(output_file, augmenter=create_orth_variants_augmenter(level=0.2, lower=0.5))
+    train_examples = list(reader(nlp))
 
 
 @pytest.mark.skip("Outdated")
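Because an augmenter is now just a callable taking (nlp, example) and yielding examples, it can also be exercised directly, without a Corpus. A minimal sketch, assuming an nlp pipeline and a train_example are already set up:

from spacy.training.augment import create_orth_variants_augmenter

# With level=1.0 the variant branch always runs; lower=0.5 lowercases
# roughly half of the augmented examples.
augmenter = create_orth_variants_augmenter(level=1.0, lower=0.5)
augmented = list(augmenter(nlp, train_example))
assert len(augmented) == 1  # exactly one example is yielded per input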

spacy/training/__init__.py

@@ -1,6 +1,7 @@
 from .corpus import Corpus  # noqa: F401
 from .example import Example, validate_examples  # noqa: F401
 from .align import Alignment  # noqa: F401
+from .augment import dont_augment, orth_variants_augmenter  # noqa: F401
 from .iob_utils import iob_to_biluo, biluo_to_iob  # noqa: F401
 from .iob_utils import offsets_to_biluo_tags, biluo_tags_to_offsets  # noqa: F401
 from .iob_utils import biluo_tags_to_spans, tags_to_entities  # noqa: F401

spacy/training/augment.py

@@ -1,30 +1,50 @@
+from typing import Callable
 import random
 import itertools
+import copy
+from functools import partial
+from ..util import registry
 
 
-def make_orth_variants_example(nlp, example, orth_variant_level=0.0):  # TODO: naming
-    raw_text = example.text
-    orig_dict = example.to_dict()
-    variant_text, variant_token_annot = make_orth_variants(
-        nlp, raw_text, orig_dict["token_annotation"], orth_variant_level
-    )
-    doc = nlp.make_doc(variant_text)
-    orig_dict["token_annotation"] = variant_token_annot
-    return example.from_dict(doc, orig_dict)
+@registry.augmenters("spacy.dont_augment.v1")
+def create_null_augmenter():
+    return dont_augment
 
 
-def make_orth_variants(nlp, raw_text, orig_token_dict, orth_variant_level=0.0):
-    if random.random() >= orth_variant_level:
-        return raw_text, orig_token_dict
-    if not orig_token_dict:
-        return raw_text, orig_token_dict
-    raw = raw_text
-    token_dict = orig_token_dict
-    lower = False
-    if random.random() >= 0.5:
-        lower = True
-        if raw is not None:
-            raw = raw.lower()
+@registry.augmenters("spacy.orth_variants.v1")
+def create_orth_variants_augmenter(level: float, lower: float) -> Callable:
+    """Create a data augmentation callback that uses orth-variant replacement.
+    The callback can be added to a corpus or other data iterator during training.
+    """
+    return partial(orth_variants_augmenter, level=level, lower=lower)
+
+
+def dont_augment(nlp, example):
+    yield example
+
+
+def orth_variants_augmenter(nlp, example, *, level: float = 0.0, lower: float = 0.0):
+    if random.random() >= level:
+        yield example
+    else:
+        raw_text = example.text
+        orig_dict = example.to_dict()
+        if not orig_dict["token_annotation"]:
+            yield example
+        else:
+            variant_text, variant_token_annot = make_orth_variants(
+                nlp,
+                raw_text,
+                orig_dict["token_annotation"],
+                lower=raw_text is not None and random.random() < lower
+            )
+            doc = nlp.make_doc(variant_text)
+            orig_dict["token_annotation"] = variant_token_annot
+            yield example.from_dict(doc, orig_dict)
+
+
+def make_orth_variants(nlp, raw, token_dict, *, lower: bool = False):
+    orig_token_dict = copy.deepcopy(token_dict)
     orth_variants = nlp.vocab.lookups.get_table("orth_variants", {})
     ndsv = orth_variants.get("single", [])
     ndpv = orth_variants.get("paired", [])
@@ -103,7 +123,7 @@ def make_orth_variants(nlp, raw_text, orig_token_dict, orth_variant_level=0.0):
             # something went wrong, abort
             # (add a warning message?)
             if not match_found:
-                return raw_text, orig_token_dict
+                return raw, orig_token_dict
             # add following whitespace
             while raw_idx < len(raw) and raw[raw_idx].isspace():
                 variant_raw += raw[raw_idx]
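The factory-plus-partial pattern above is also the hook for third-party augmenters. A sketch of a custom augmenter under a hypothetical name — it oversamples examples instead of rewriting them, and assumes the examples carry raw text:

import random
from functools import partial
from spacy.util import registry


@registry.augmenters("my_project.oversample.v1")  # hypothetical name
def create_oversample_augmenter(level: float):
    return partial(oversample_augmenter, level=level)


def oversample_augmenter(nlp, example, *, level: float):
    # Always yield the original annotation ...
    yield example
    # ... and occasionally a second copy rebuilt from the same dict,
    # mirroring the to_dict()/from_dict() round trip used above.
    if random.random() < level:
        doc = nlp.make_doc(example.text)
        yield example.from_dict(doc, example.to_dict())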

spacy/training/corpus.py

@@ -1,9 +1,11 @@
 import warnings
 from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable
+from typing import Optional
 from pathlib import Path
 import srsly
 
 from .. import util
+from .augment import dont_augment
 from .example import Example
 from ..errors import Warnings
 from ..tokens import DocBin, Doc
@@ -18,9 +20,19 @@ FILE_TYPE = ".spacy"
 
 @util.registry.readers("spacy.Corpus.v1")
 def create_docbin_reader(
-    path: Path, gold_preproc: bool, max_length: int = 0, limit: int = 0
+    path: Path,
+    gold_preproc: bool,
+    max_length: int = 0,
+    limit: int = 0,
+    augmenter: Optional[Callable] = None,
 ) -> Callable[["Language"], Iterable[Example]]:
-    return Corpus(path, gold_preproc=gold_preproc, max_length=max_length, limit=limit)
+    return Corpus(
+        path,
+        gold_preproc=gold_preproc,
+        max_length=max_length,
+        limit=limit,
+        augmenter=augmenter,
+    )
 
 
 @util.registry.readers("spacy.JsonlReader.v1")
@@ -70,6 +82,8 @@ class Corpus:
         0, which indicates no limit.
     limit (int): Limit corpus to a subset of examples, e.g. for debugging.
         Defaults to 0, which indicates no limit.
+    augmenter (Callable[[Language, Example], Iterable[Example]]): Optional data
+        augmentation function, to extrapolate additional examples from your annotations.
 
     DOCS: https://nightly.spacy.io/api/corpus
     """
@@ -81,11 +95,13 @@ class Corpus:
         limit: int = 0,
         gold_preproc: bool = False,
         max_length: int = 0,
+        augmenter: Optional[Callable] = None,
     ) -> None:
         self.path = util.ensure_path(path)
         self.gold_preproc = gold_preproc
         self.max_length = max_length
         self.limit = limit
+        self.augmenter = augmenter if augmenter is not None else dont_augment
 
     def __call__(self, nlp: "Language") -> Iterator[Example]:
         """Yield examples from the data.
@@ -100,7 +116,9 @@ class Corpus:
             examples = self.make_examples_gold_preproc(nlp, ref_docs)
         else:
             examples = self.make_examples(nlp, ref_docs)
-        yield from examples
+        for real_eg in examples:
+            for augmented_eg in self.augmenter(nlp, real_eg):
+                yield augmented_eg
 
     def _make_example(
         self, nlp: "Language", reference: Doc, gold_preproc: bool
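For reference, the same wiring in plain Python rather than through the config — a sketch, where "./train.spacy" is a placeholder path to a serialized DocBin:

from spacy.lang.en import English
from spacy.training import Corpus
from spacy.training.augment import create_orth_variants_augmenter

nlp = English()
corpus = Corpus(
    "./train.spacy",  # placeholder path
    augmenter=create_orth_variants_augmenter(level=0.1, lower=0.5),
)
# Each underlying example passes through the augmenter, which yields
# either the original or an orth-variant copy.
train_examples = list(corpus(nlp))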

spacy/util.py

@@ -81,6 +81,7 @@ class registry(thinc.registry):
     callbacks = catalogue.create("spacy", "callbacks")
     batchers = catalogue.create("spacy", "batchers", entry_points=True)
     readers = catalogue.create("spacy", "readers", entry_points=True)
+    augmenters = catalogue.create("spacy", "augmenters", entry_points=True)
     loggers = catalogue.create("spacy", "loggers", entry_points=True)
     # These are factories registered via third-party packages and the
     # spacy_factories entry point. This registry only exists so we can easily
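Since augmenters is a regular catalogue registry (with entry points enabled), registered factories can be resolved by name — which is what happens when the config line {"@augmenters": "spacy.orth_variants.v1", ...} is interpreted. A sketch:

from spacy.util import registry

# Look up the factory registered as "spacy.orth_variants.v1" and call it
# with the same arguments the config block would pass.
create_augmenter = registry.augmenters.get("spacy.orth_variants.v1")
augmenter = create_augmenter(level=0.1, lower=0.5)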

website/docs/api/corpus.md

@@ -74,6 +74,7 @@ train/test skew.
 | `gold_preproc` | Whether to set up the Example object with gold-standard sentences and tokens for the predictions. Defaults to `False`. ~~bool~~ |
 | `max_length` | Maximum document length. Longer documents will be split into sentences, if sentence boundaries are available. Defaults to `0` for no limit. ~~int~~ |
 | `limit` | Limit corpus to a subset of examples, e.g. for debugging. Defaults to `0` for no limit. ~~int~~ |
+| `augmenter` | Optional data augmentation callback. ~~Callable[[Language, Example], Iterable[Example]]~~ |
 
 ## Corpus.\_\_call\_\_ {#call tag="method"}

website/docs/usage/training.md

@@ -6,6 +6,7 @@ menu:
   - ['Introduction', 'basics']
   - ['Quickstart', 'quickstart']
   - ['Config System', 'config']
+  <!-- - ['Data Utilities', 'data'] -->
   - ['Custom Functions', 'custom-functions']
   - ['Parallel Training', 'parallel-training']
   - ['Internal API', 'api']
@@ -505,6 +506,16 @@ still look good.
 
 </Accordion>
 
+<!--
+## Data Utilities {#data-utilities}
+
+* spacy convert
+* The [corpora] block
+* Custom corpus class
+* Minibatching
+* Data augmentation
+
+-->
+
 ## Custom Functions {#custom-functions}
 
 Registered functions in the training config files can refer to built-in