mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Support data augmentation in Corpus (#6155)
* Support data augmentation in Corpus * Note initial docs for data augmentation * Add augmenter to quickstart * Fix flake8 * Format * Fix test * Update spacy/tests/training/test_training.py * Improve data augmentation arguments * Update templates * Move randomization out into caller * Refactor * Update spacy/training/augment.py * Update spacy/tests/training/test_training.py * Fix augment * Fix test
This commit is contained in:
parent
cad4dbddaa
commit
a976da168c
|
@ -270,6 +270,7 @@ factory = "{{ pipe }}"
|
|||
@readers = "spacy.Corpus.v1"
|
||||
path = ${paths.train}
|
||||
max_length = {{ 500 if hardware == "gpu" else 2000 }}
|
||||
augmenter = {"@augmenters": "spacy.orth_variants.v1", "level": 0.1, "lower": 0.5}
|
||||
|
||||
[corpora.dev]
|
||||
@readers = "spacy.Corpus.v1"
|
||||
|
|
|
@ -35,6 +35,11 @@ gold_preproc = false
|
|||
max_length = 0
|
||||
# Limitation on number of training examples
|
||||
limit = 0
|
||||
# Apply some simply data augmentation, where we replace tokens with variations.
|
||||
# This is especially useful for punctuation and case replacement, to help
|
||||
# generalize beyond corpora that don't have smart-quotes, or only have smart
|
||||
# quotes, etc.
|
||||
augmenter = {"@augmenters": "spacy.orth_variants.v1", "level": 0.1, "lower": 0.5}
|
||||
|
||||
[corpora.dev]
|
||||
@readers = "spacy.Corpus.v1"
|
||||
|
|
|
@ -4,7 +4,7 @@ from spacy.training import biluo_tags_to_spans, iob_to_biluo
|
|||
from spacy.training import Corpus, docs_to_json
|
||||
from spacy.training.example import Example
|
||||
from spacy.training.converters import json_to_docs
|
||||
from spacy.training.augment import make_orth_variants_example
|
||||
from spacy.training.augment import create_orth_variants_augmenter
|
||||
from spacy.lang.en import English
|
||||
from spacy.tokens import Doc, DocBin
|
||||
from spacy.util import get_words_and_spaces, minibatch
|
||||
|
@ -496,9 +496,8 @@ def test_make_orth_variants(doc):
|
|||
output_file = tmpdir / "roundtrip.spacy"
|
||||
DocBin(docs=[doc]).to_disk(output_file)
|
||||
# due to randomness, test only that this runs with no errors for now
|
||||
reader = Corpus(output_file)
|
||||
train_example = next(reader(nlp))
|
||||
make_orth_variants_example(nlp, train_example, orth_variant_level=0.2)
|
||||
reader = Corpus(output_file, augmenter=create_orth_variants_augmenter(level=0.2, lower=0.5))
|
||||
train_examples = list(reader(nlp))
|
||||
|
||||
|
||||
@pytest.mark.skip("Outdated")
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from .corpus import Corpus # noqa: F401
|
||||
from .example import Example, validate_examples # noqa: F401
|
||||
from .align import Alignment # noqa: F401
|
||||
from .augment import dont_augment, orth_variants_augmenter # noqa: F401
|
||||
from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401
|
||||
from .iob_utils import offsets_to_biluo_tags, biluo_tags_to_offsets # noqa: F401
|
||||
from .iob_utils import biluo_tags_to_spans, tags_to_entities # noqa: F401
|
||||
|
|
|
@ -1,30 +1,50 @@
|
|||
from typing import Callable
|
||||
import random
|
||||
import itertools
|
||||
import copy
|
||||
from functools import partial
|
||||
from ..util import registry
|
||||
|
||||
|
||||
def make_orth_variants_example(nlp, example, orth_variant_level=0.0): # TODO: naming
|
||||
raw_text = example.text
|
||||
orig_dict = example.to_dict()
|
||||
variant_text, variant_token_annot = make_orth_variants(
|
||||
nlp, raw_text, orig_dict["token_annotation"], orth_variant_level
|
||||
)
|
||||
doc = nlp.make_doc(variant_text)
|
||||
orig_dict["token_annotation"] = variant_token_annot
|
||||
return example.from_dict(doc, orig_dict)
|
||||
@registry.augmenters("spacy.dont_augment.v1")
|
||||
def create_null_augmenter():
|
||||
return dont_augment
|
||||
|
||||
|
||||
def make_orth_variants(nlp, raw_text, orig_token_dict, orth_variant_level=0.0):
|
||||
if random.random() >= orth_variant_level:
|
||||
return raw_text, orig_token_dict
|
||||
if not orig_token_dict:
|
||||
return raw_text, orig_token_dict
|
||||
raw = raw_text
|
||||
token_dict = orig_token_dict
|
||||
lower = False
|
||||
if random.random() >= 0.5:
|
||||
lower = True
|
||||
if raw is not None:
|
||||
raw = raw.lower()
|
||||
@registry.augmenters("spacy.orth_variants.v1")
|
||||
def create_orth_variants_augmenter(level: float, lower: float) -> Callable:
|
||||
"""Create a data augmentation callback that uses orth-variant replacement.
|
||||
The callback can be added to a corpus or other data iterator during training.
|
||||
"""
|
||||
return partial(orth_variants_augmenter, level=level, lower=lower)
|
||||
|
||||
|
||||
def dont_augment(nlp, example):
|
||||
yield example
|
||||
|
||||
|
||||
def orth_variants_augmenter(nlp, example, *, level: float = 0.0, lower: float=0.0):
|
||||
if random.random() >= level:
|
||||
yield example
|
||||
else:
|
||||
raw_text = example.text
|
||||
orig_dict = example.to_dict()
|
||||
if not orig_dict["token_annotation"]:
|
||||
yield example
|
||||
else:
|
||||
variant_text, variant_token_annot = make_orth_variants(
|
||||
nlp,
|
||||
raw_text,
|
||||
orig_dict["token_annotation"],
|
||||
lower=raw_text is not None and random.random() < lower
|
||||
)
|
||||
doc = nlp.make_doc(variant_text)
|
||||
orig_dict["token_annotation"] = variant_token_annot
|
||||
yield example.from_dict(doc, orig_dict)
|
||||
|
||||
|
||||
def make_orth_variants(nlp, raw, token_dict, *, lower: bool=False):
|
||||
orig_token_dict = copy.deepcopy(token_dict)
|
||||
orth_variants = nlp.vocab.lookups.get_table("orth_variants", {})
|
||||
ndsv = orth_variants.get("single", [])
|
||||
ndpv = orth_variants.get("paired", [])
|
||||
|
@ -103,7 +123,7 @@ def make_orth_variants(nlp, raw_text, orig_token_dict, orth_variant_level=0.0):
|
|||
# something went wrong, abort
|
||||
# (add a warning message?)
|
||||
if not match_found:
|
||||
return raw_text, orig_token_dict
|
||||
return raw, orig_token_dict
|
||||
# add following whitespace
|
||||
while raw_idx < len(raw) and raw[raw_idx].isspace():
|
||||
variant_raw += raw[raw_idx]
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
import warnings
|
||||
from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
import srsly
|
||||
|
||||
from .. import util
|
||||
from .augment import dont_augment
|
||||
from .example import Example
|
||||
from ..errors import Warnings
|
||||
from ..tokens import DocBin, Doc
|
||||
|
@ -18,9 +20,19 @@ FILE_TYPE = ".spacy"
|
|||
|
||||
@util.registry.readers("spacy.Corpus.v1")
|
||||
def create_docbin_reader(
|
||||
path: Path, gold_preproc: bool, max_length: int = 0, limit: int = 0
|
||||
path: Path,
|
||||
gold_preproc: bool,
|
||||
max_length: int = 0,
|
||||
limit: int = 0,
|
||||
augmenter: Optional[Callable] = None,
|
||||
) -> Callable[["Language"], Iterable[Example]]:
|
||||
return Corpus(path, gold_preproc=gold_preproc, max_length=max_length, limit=limit)
|
||||
return Corpus(
|
||||
path,
|
||||
gold_preproc=gold_preproc,
|
||||
max_length=max_length,
|
||||
limit=limit,
|
||||
augmenter=augmenter,
|
||||
)
|
||||
|
||||
|
||||
@util.registry.readers("spacy.JsonlReader.v1")
|
||||
|
@ -70,6 +82,8 @@ class Corpus:
|
|||
0, which indicates no limit.
|
||||
limit (int): Limit corpus to a subset of examples, e.g. for debugging.
|
||||
Defaults to 0, which indicates no limit.
|
||||
augment (Callable[Example, Iterable[Example]]): Optional data augmentation
|
||||
function, to extrapolate additional examples from your annotations.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/corpus
|
||||
"""
|
||||
|
@ -81,11 +95,13 @@ class Corpus:
|
|||
limit: int = 0,
|
||||
gold_preproc: bool = False,
|
||||
max_length: int = 0,
|
||||
augmenter: Optional[Callable] = None,
|
||||
) -> None:
|
||||
self.path = util.ensure_path(path)
|
||||
self.gold_preproc = gold_preproc
|
||||
self.max_length = max_length
|
||||
self.limit = limit
|
||||
self.augmenter = augmenter if augmenter is not None else dont_augment
|
||||
|
||||
def __call__(self, nlp: "Language") -> Iterator[Example]:
|
||||
"""Yield examples from the data.
|
||||
|
@ -100,7 +116,9 @@ class Corpus:
|
|||
examples = self.make_examples_gold_preproc(nlp, ref_docs)
|
||||
else:
|
||||
examples = self.make_examples(nlp, ref_docs)
|
||||
yield from examples
|
||||
for real_eg in examples:
|
||||
for augmented_eg in self.augmenter(nlp, real_eg):
|
||||
yield augmented_eg
|
||||
|
||||
def _make_example(
|
||||
self, nlp: "Language", reference: Doc, gold_preproc: bool
|
||||
|
|
|
@ -81,6 +81,7 @@ class registry(thinc.registry):
|
|||
callbacks = catalogue.create("spacy", "callbacks")
|
||||
batchers = catalogue.create("spacy", "batchers", entry_points=True)
|
||||
readers = catalogue.create("spacy", "readers", entry_points=True)
|
||||
augmenters = catalogue.create("spacy", "augmenters", entry_points=True)
|
||||
loggers = catalogue.create("spacy", "loggers", entry_points=True)
|
||||
# These are factories registered via third-party packages and the
|
||||
# spacy_factories entry point. This registry only exists so we can easily
|
||||
|
|
|
@ -74,6 +74,7 @@ train/test skew.
|
|||
| `gold_preproc` | Whether to set up the Example object with gold-standard sentences and tokens for the predictions. Defaults to `False`. ~~bool~~ |
|
||||
| `max_length` | Maximum document length. Longer documents will be split into sentences, if sentence boundaries are available. Defaults to `0` for no limit. ~~int~~ |
|
||||
| `limit` | Limit corpus to a subset of examples, e.g. for debugging. Defaults to `0` for no limit. ~~int~~ |
|
||||
| `augmenter` | Optional data augmentation callback. ~~Callable[[Language, Example], Iterable[Example]]~~
|
||||
|
||||
## Corpus.\_\_call\_\_ {#call tag="method"}
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ menu:
|
|||
- ['Introduction', 'basics']
|
||||
- ['Quickstart', 'quickstart']
|
||||
- ['Config System', 'config']
|
||||
<!-- - ['Data Utilities', 'data'] -->
|
||||
- ['Custom Functions', 'custom-functions']
|
||||
- ['Parallel Training', 'parallel-training']
|
||||
- ['Internal API', 'api']
|
||||
|
@ -505,6 +506,16 @@ still look good.
|
|||
|
||||
</Accordion>
|
||||
|
||||
<!--
|
||||
## Data Utilities {#data-utilities}
|
||||
|
||||
* spacy convert
|
||||
* The [corpora] block
|
||||
* Custom corpus class
|
||||
* Minibatching
|
||||
* Data augmentation
|
||||
-->
|
||||
|
||||
## Custom Functions {#custom-functions}
|
||||
|
||||
Registered functions in the training config files can refer to built-in
|
||||
|
|
Loading…
Reference in New Issue
Block a user