From a976da168c74227281bbdc7b2aa4ab93a0f2afba Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 28 Sep 2020 03:03:27 +0200 Subject: [PATCH] Support data augmentation in Corpus (#6155) * Support data augmentation in Corpus * Note initial docs for data augmentation * Add augmenter to quickstart * Fix flake8 * Format * Fix test * Update spacy/tests/training/test_training.py * Improve data augmentation arguments * Update templates * Move randomization out into caller * Refactor * Update spacy/training/augment.py * Update spacy/tests/training/test_training.py * Fix augment * Fix test --- spacy/cli/templates/quickstart_training.jinja | 1 + spacy/default_config.cfg | 5 ++ spacy/tests/training/test_training.py | 7 +- spacy/training/__init__.py | 1 + spacy/training/augment.py | 64 ++++++++++++------- spacy/training/corpus.py | 24 ++++++- spacy/util.py | 1 + website/docs/api/corpus.md | 1 + website/docs/usage/training.md | 11 ++++ 9 files changed, 86 insertions(+), 29 deletions(-) diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja index 9a8b9d1d7..56faeebfa 100644 --- a/spacy/cli/templates/quickstart_training.jinja +++ b/spacy/cli/templates/quickstart_training.jinja @@ -270,6 +270,7 @@ factory = "{{ pipe }}" @readers = "spacy.Corpus.v1" path = ${paths.train} max_length = {{ 500 if hardware == "gpu" else 2000 }} +augmenter = {"@augmenters": "spacy.orth_variants.v1", "level": 0.1, "lower": 0.5} [corpora.dev] @readers = "spacy.Corpus.v1" diff --git a/spacy/default_config.cfg b/spacy/default_config.cfg index 6f8c0aa00..63a0742e3 100644 --- a/spacy/default_config.cfg +++ b/spacy/default_config.cfg @@ -35,6 +35,11 @@ gold_preproc = false max_length = 0 # Limitation on number of training examples limit = 0 +# Apply some simple data augmentation, where we replace tokens with variations. 
+# This is especially useful for punctuation and case replacement, to help +# generalize beyond corpora that don't have smart-quotes, or only have smart +# quotes, etc. +augmenter = {"@augmenters": "spacy.orth_variants.v1", "level": 0.1, "lower": 0.5} [corpora.dev] @readers = "spacy.Corpus.v1" diff --git a/spacy/tests/training/test_training.py b/spacy/tests/training/test_training.py index a04e6aadd..5311fae1e 100644 --- a/spacy/tests/training/test_training.py +++ b/spacy/tests/training/test_training.py @@ -4,7 +4,7 @@ from spacy.training import biluo_tags_to_spans, iob_to_biluo from spacy.training import Corpus, docs_to_json from spacy.training.example import Example from spacy.training.converters import json_to_docs -from spacy.training.augment import make_orth_variants_example +from spacy.training.augment import create_orth_variants_augmenter from spacy.lang.en import English from spacy.tokens import Doc, DocBin from spacy.util import get_words_and_spaces, minibatch @@ -496,9 +496,8 @@ def test_make_orth_variants(doc): output_file = tmpdir / "roundtrip.spacy" DocBin(docs=[doc]).to_disk(output_file) # due to randomness, test only that this runs with no errors for now - reader = Corpus(output_file) - train_example = next(reader(nlp)) - make_orth_variants_example(nlp, train_example, orth_variant_level=0.2) + reader = Corpus(output_file, augmenter=create_orth_variants_augmenter(level=0.2, lower=0.5)) + train_examples = list(reader(nlp)) @pytest.mark.skip("Outdated") diff --git a/spacy/training/__init__.py b/spacy/training/__init__.py index 9172dde25..f71a5f521 100644 --- a/spacy/training/__init__.py +++ b/spacy/training/__init__.py @@ -1,6 +1,7 @@ from .corpus import Corpus # noqa: F401 from .example import Example, validate_examples # noqa: F401 from .align import Alignment # noqa: F401 +from .augment import dont_augment, orth_variants_augmenter # noqa: F401 from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401 from .iob_utils import offsets_to_biluo_tags, 
biluo_tags_to_offsets # noqa: F401 from .iob_utils import biluo_tags_to_spans, tags_to_entities # noqa: F401 diff --git a/spacy/training/augment.py b/spacy/training/augment.py index 4a01c8589..4d487ce93 100644 --- a/spacy/training/augment.py +++ b/spacy/training/augment.py @@ -1,30 +1,50 @@ +from typing import Callable import random import itertools +import copy +from functools import partial +from ..util import registry -def make_orth_variants_example(nlp, example, orth_variant_level=0.0): # TODO: naming - raw_text = example.text - orig_dict = example.to_dict() - variant_text, variant_token_annot = make_orth_variants( - nlp, raw_text, orig_dict["token_annotation"], orth_variant_level - ) - doc = nlp.make_doc(variant_text) - orig_dict["token_annotation"] = variant_token_annot - return example.from_dict(doc, orig_dict) +@registry.augmenters("spacy.dont_augment.v1") +def create_null_augmenter(): + return dont_augment -def make_orth_variants(nlp, raw_text, orig_token_dict, orth_variant_level=0.0): - if random.random() >= orth_variant_level: - return raw_text, orig_token_dict - if not orig_token_dict: - return raw_text, orig_token_dict - raw = raw_text - token_dict = orig_token_dict - lower = False - if random.random() >= 0.5: - lower = True - if raw is not None: - raw = raw.lower() +@registry.augmenters("spacy.orth_variants.v1") +def create_orth_variants_augmenter(level: float, lower: float) -> Callable: + """Create a data augmentation callback that uses orth-variant replacement. + The callback can be added to a corpus or other data iterator during training. 
+ """ + return partial(orth_variants_augmenter, level=level, lower=lower) + + +def dont_augment(nlp, example): + yield example + + +def orth_variants_augmenter(nlp, example, *, level: float = 0.0, lower: float=0.0): + if random.random() >= level: + yield example + else: + raw_text = example.text + orig_dict = example.to_dict() + if not orig_dict["token_annotation"]: + yield example + else: + variant_text, variant_token_annot = make_orth_variants( + nlp, + raw_text, + orig_dict["token_annotation"], + lower=raw_text is not None and random.random() < lower + ) + doc = nlp.make_doc(variant_text) + orig_dict["token_annotation"] = variant_token_annot + yield example.from_dict(doc, orig_dict) + + +def make_orth_variants(nlp, raw, token_dict, *, lower: bool=False): + orig_token_dict = copy.deepcopy(token_dict) orth_variants = nlp.vocab.lookups.get_table("orth_variants", {}) ndsv = orth_variants.get("single", []) ndpv = orth_variants.get("paired", []) @@ -103,7 +123,7 @@ def make_orth_variants(nlp, raw_text, orig_token_dict, orth_variant_level=0.0): # something went wrong, abort # (add a warning message?) if not match_found: - return raw_text, orig_token_dict + return raw, orig_token_dict # add following whitespace while raw_idx < len(raw) and raw[raw_idx].isspace(): variant_raw += raw[raw_idx] diff --git a/spacy/training/corpus.py b/spacy/training/corpus.py index 12bda486e..90eb62474 100644 --- a/spacy/training/corpus.py +++ b/spacy/training/corpus.py @@ -1,9 +1,11 @@ import warnings from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable +from typing import Optional from pathlib import Path import srsly from .. 
import util +from .augment import dont_augment from .example import Example from ..errors import Warnings from ..tokens import DocBin, Doc @@ -18,9 +20,19 @@ FILE_TYPE = ".spacy" @util.registry.readers("spacy.Corpus.v1") def create_docbin_reader( - path: Path, gold_preproc: bool, max_length: int = 0, limit: int = 0 + path: Path, + gold_preproc: bool, + max_length: int = 0, + limit: int = 0, + augmenter: Optional[Callable] = None, ) -> Callable[["Language"], Iterable[Example]]: - return Corpus(path, gold_preproc=gold_preproc, max_length=max_length, limit=limit) + return Corpus( + path, + gold_preproc=gold_preproc, + max_length=max_length, + limit=limit, + augmenter=augmenter, + ) @util.registry.readers("spacy.JsonlReader.v1") @@ -70,6 +82,8 @@ class Corpus: 0, which indicates no limit. limit (int): Limit corpus to a subset of examples, e.g. for debugging. Defaults to 0, which indicates no limit. + augmenter (Callable[[Language, Example], Iterable[Example]]): Optional data augmentation + function, to extrapolate additional examples from your annotations. DOCS: https://nightly.spacy.io/api/corpus """ @@ -81,11 +95,13 @@ class Corpus: limit: int = 0, gold_preproc: bool = False, max_length: int = 0, + augmenter: Optional[Callable] = None, ) -> None: self.path = util.ensure_path(path) self.gold_preproc = gold_preproc self.max_length = max_length self.limit = limit + self.augmenter = augmenter if augmenter is not None else dont_augment def __call__(self, nlp: "Language") -> Iterator[Example]: """Yield examples from the data. 
@@ -100,7 +116,9 @@ class Corpus: examples = self.make_examples_gold_preproc(nlp, ref_docs) else: examples = self.make_examples(nlp, ref_docs) - yield from examples + for real_eg in examples: + for augmented_eg in self.augmenter(nlp, real_eg): + yield augmented_eg def _make_example( self, nlp: "Language", reference: Doc, gold_preproc: bool diff --git a/spacy/util.py b/spacy/util.py index 01232f5c5..1cc7abf57 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -81,6 +81,7 @@ class registry(thinc.registry): callbacks = catalogue.create("spacy", "callbacks") batchers = catalogue.create("spacy", "batchers", entry_points=True) readers = catalogue.create("spacy", "readers", entry_points=True) + augmenters = catalogue.create("spacy", "augmenters", entry_points=True) loggers = catalogue.create("spacy", "loggers", entry_points=True) # These are factories registered via third-party packages and the # spacy_factories entry point. This registry only exists so we can easily diff --git a/website/docs/api/corpus.md b/website/docs/api/corpus.md index 2b308d618..e7d6773e6 100644 --- a/website/docs/api/corpus.md +++ b/website/docs/api/corpus.md @@ -74,6 +74,7 @@ train/test skew. |  `gold_preproc` | Whether to set up the Example object with gold-standard sentences and tokens for the predictions. Defaults to `False`. ~~bool~~ | | `max_length` | Maximum document length. Longer documents will be split into sentences, if sentence boundaries are available. Defaults to `0` for no limit. ~~int~~ | | `limit` | Limit corpus to a subset of examples, e.g. for debugging. Defaults to `0` for no limit. ~~int~~ | +| `augmenter` | Optional data augmentation callback. 
~~Callable[[Language, Example], Iterable[Example]]~~ | ## Corpus.\_\_call\_\_ {#call tag="method"} diff --git a/website/docs/usage/training.md b/website/docs/usage/training.md index 54be6b367..eb02b135a 100644 --- a/website/docs/usage/training.md +++ b/website/docs/usage/training.md @@ -6,6 +6,7 @@ menu: - ['Introduction', 'basics'] - ['Quickstart', 'quickstart'] - ['Config System', 'config'] + - ['Custom Functions', 'custom-functions'] - ['Parallel Training', 'parallel-training'] - ['Internal API', 'api'] @@ -505,6 +506,16 @@ still look good. + + ## Custom Functions {#custom-functions} Registered functions in the training config files can refer to built-in