mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
Support data augmentation in Corpus (#6155)
* Support data augmentation in Corpus * Note initial docs for data augmentation * Add augmenter to quickstart * Fix flake8 * Format * Fix test * Update spacy/tests/training/test_training.py * Improve data augmentation arguments * Update templates * Move randomization out into caller * Refactor * Update spacy/training/augment.py * Update spacy/tests/training/test_training.py * Fix augment * Fix test
This commit is contained in:
parent
cad4dbddaa
commit
a976da168c
|
@ -270,6 +270,7 @@ factory = "{{ pipe }}"
|
||||||
@readers = "spacy.Corpus.v1"
|
@readers = "spacy.Corpus.v1"
|
||||||
path = ${paths.train}
|
path = ${paths.train}
|
||||||
max_length = {{ 500 if hardware == "gpu" else 2000 }}
|
max_length = {{ 500 if hardware == "gpu" else 2000 }}
|
||||||
|
augmenter = {"@augmenters": "spacy.orth_variants.v1", "level": 0.1, "lower": 0.5}
|
||||||
|
|
||||||
[corpora.dev]
|
[corpora.dev]
|
||||||
@readers = "spacy.Corpus.v1"
|
@readers = "spacy.Corpus.v1"
|
||||||
|
|
|
@ -35,6 +35,11 @@ gold_preproc = false
|
||||||
max_length = 0
|
max_length = 0
|
||||||
# Limitation on number of training examples
|
# Limitation on number of training examples
|
||||||
limit = 0
|
limit = 0
|
||||||
|
# Apply some simple data augmentation, where we replace tokens with variations.
|
||||||
|
# This is especially useful for punctuation and case replacement, to help
|
||||||
|
# generalize beyond corpora that don't have smart-quotes, or only have smart
|
||||||
|
# quotes, etc.
|
||||||
|
augmenter = {"@augmenters": "spacy.orth_variants.v1", "level": 0.1, "lower": 0.5}
|
||||||
|
|
||||||
[corpora.dev]
|
[corpora.dev]
|
||||||
@readers = "spacy.Corpus.v1"
|
@readers = "spacy.Corpus.v1"
|
||||||
|
|
|
@ -4,7 +4,7 @@ from spacy.training import biluo_tags_to_spans, iob_to_biluo
|
||||||
from spacy.training import Corpus, docs_to_json
|
from spacy.training import Corpus, docs_to_json
|
||||||
from spacy.training.example import Example
|
from spacy.training.example import Example
|
||||||
from spacy.training.converters import json_to_docs
|
from spacy.training.converters import json_to_docs
|
||||||
from spacy.training.augment import make_orth_variants_example
|
from spacy.training.augment import create_orth_variants_augmenter
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.tokens import Doc, DocBin
|
from spacy.tokens import Doc, DocBin
|
||||||
from spacy.util import get_words_and_spaces, minibatch
|
from spacy.util import get_words_and_spaces, minibatch
|
||||||
|
@ -496,9 +496,8 @@ def test_make_orth_variants(doc):
|
||||||
output_file = tmpdir / "roundtrip.spacy"
|
output_file = tmpdir / "roundtrip.spacy"
|
||||||
DocBin(docs=[doc]).to_disk(output_file)
|
DocBin(docs=[doc]).to_disk(output_file)
|
||||||
# due to randomness, test only that this runs with no errors for now
|
# due to randomness, test only that this runs with no errors for now
|
||||||
reader = Corpus(output_file)
|
reader = Corpus(output_file, augmenter=create_orth_variants_augmenter(level=0.2, lower=0.5))
|
||||||
train_example = next(reader(nlp))
|
train_examples = list(reader(nlp))
|
||||||
make_orth_variants_example(nlp, train_example, orth_variant_level=0.2)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("Outdated")
|
@pytest.mark.skip("Outdated")
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from .corpus import Corpus # noqa: F401
|
from .corpus import Corpus # noqa: F401
|
||||||
from .example import Example, validate_examples # noqa: F401
|
from .example import Example, validate_examples # noqa: F401
|
||||||
from .align import Alignment # noqa: F401
|
from .align import Alignment # noqa: F401
|
||||||
|
from .augment import dont_augment, orth_variants_augmenter # noqa: F401
|
||||||
from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401
|
from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401
|
||||||
from .iob_utils import offsets_to_biluo_tags, biluo_tags_to_offsets # noqa: F401
|
from .iob_utils import offsets_to_biluo_tags, biluo_tags_to_offsets # noqa: F401
|
||||||
from .iob_utils import biluo_tags_to_spans, tags_to_entities # noqa: F401
|
from .iob_utils import biluo_tags_to_spans, tags_to_entities # noqa: F401
|
||||||
|
|
|
@ -1,30 +1,50 @@
|
||||||
|
from typing import Callable
|
||||||
import random
|
import random
|
||||||
import itertools
|
import itertools
|
||||||
|
import copy
|
||||||
|
from functools import partial
|
||||||
|
from ..util import registry
|
||||||
|
|
||||||
|
|
||||||
def make_orth_variants_example(nlp, example, orth_variant_level=0.0): # TODO: naming
|
@registry.augmenters("spacy.dont_augment.v1")
|
||||||
raw_text = example.text
|
def create_null_augmenter():
|
||||||
orig_dict = example.to_dict()
|
return dont_augment
|
||||||
variant_text, variant_token_annot = make_orth_variants(
|
|
||||||
nlp, raw_text, orig_dict["token_annotation"], orth_variant_level
|
|
||||||
)
|
|
||||||
doc = nlp.make_doc(variant_text)
|
|
||||||
orig_dict["token_annotation"] = variant_token_annot
|
|
||||||
return example.from_dict(doc, orig_dict)
|
|
||||||
|
|
||||||
|
|
||||||
def make_orth_variants(nlp, raw_text, orig_token_dict, orth_variant_level=0.0):
|
@registry.augmenters("spacy.orth_variants.v1")
|
||||||
if random.random() >= orth_variant_level:
|
def create_orth_variants_augmenter(level: float, lower: float) -> Callable:
|
||||||
return raw_text, orig_token_dict
|
"""Create a data augmentation callback that uses orth-variant replacement.
|
||||||
if not orig_token_dict:
|
The callback can be added to a corpus or other data iterator during training.
|
||||||
return raw_text, orig_token_dict
|
"""
|
||||||
raw = raw_text
|
return partial(orth_variants_augmenter, level=level, lower=lower)
|
||||||
token_dict = orig_token_dict
|
|
||||||
lower = False
|
|
||||||
if random.random() >= 0.5:
|
def dont_augment(nlp, example):
|
||||||
lower = True
|
yield example
|
||||||
if raw is not None:
|
|
||||||
raw = raw.lower()
|
|
||||||
|
def orth_variants_augmenter(nlp, example, *, level: float = 0.0, lower: float=0.0):
|
||||||
|
if random.random() >= level:
|
||||||
|
yield example
|
||||||
|
else:
|
||||||
|
raw_text = example.text
|
||||||
|
orig_dict = example.to_dict()
|
||||||
|
if not orig_dict["token_annotation"]:
|
||||||
|
yield example
|
||||||
|
else:
|
||||||
|
variant_text, variant_token_annot = make_orth_variants(
|
||||||
|
nlp,
|
||||||
|
raw_text,
|
||||||
|
orig_dict["token_annotation"],
|
||||||
|
lower=raw_text is not None and random.random() < lower
|
||||||
|
)
|
||||||
|
doc = nlp.make_doc(variant_text)
|
||||||
|
orig_dict["token_annotation"] = variant_token_annot
|
||||||
|
yield example.from_dict(doc, orig_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def make_orth_variants(nlp, raw, token_dict, *, lower: bool=False):
|
||||||
|
orig_token_dict = copy.deepcopy(token_dict)
|
||||||
orth_variants = nlp.vocab.lookups.get_table("orth_variants", {})
|
orth_variants = nlp.vocab.lookups.get_table("orth_variants", {})
|
||||||
ndsv = orth_variants.get("single", [])
|
ndsv = orth_variants.get("single", [])
|
||||||
ndpv = orth_variants.get("paired", [])
|
ndpv = orth_variants.get("paired", [])
|
||||||
|
@ -103,7 +123,7 @@ def make_orth_variants(nlp, raw_text, orig_token_dict, orth_variant_level=0.0):
|
||||||
# something went wrong, abort
|
# something went wrong, abort
|
||||||
# (add a warning message?)
|
# (add a warning message?)
|
||||||
if not match_found:
|
if not match_found:
|
||||||
return raw_text, orig_token_dict
|
return raw, orig_token_dict
|
||||||
# add following whitespace
|
# add following whitespace
|
||||||
while raw_idx < len(raw) and raw[raw_idx].isspace():
|
while raw_idx < len(raw) and raw[raw_idx].isspace():
|
||||||
variant_raw += raw[raw_idx]
|
variant_raw += raw[raw_idx]
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable
|
from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable
|
||||||
|
from typing import Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import srsly
|
import srsly
|
||||||
|
|
||||||
from .. import util
|
from .. import util
|
||||||
|
from .augment import dont_augment
|
||||||
from .example import Example
|
from .example import Example
|
||||||
from ..errors import Warnings
|
from ..errors import Warnings
|
||||||
from ..tokens import DocBin, Doc
|
from ..tokens import DocBin, Doc
|
||||||
|
@ -18,9 +20,19 @@ FILE_TYPE = ".spacy"
|
||||||
|
|
||||||
@util.registry.readers("spacy.Corpus.v1")
|
@util.registry.readers("spacy.Corpus.v1")
|
||||||
def create_docbin_reader(
|
def create_docbin_reader(
|
||||||
path: Path, gold_preproc: bool, max_length: int = 0, limit: int = 0
|
path: Path,
|
||||||
|
gold_preproc: bool,
|
||||||
|
max_length: int = 0,
|
||||||
|
limit: int = 0,
|
||||||
|
augmenter: Optional[Callable] = None,
|
||||||
) -> Callable[["Language"], Iterable[Example]]:
|
) -> Callable[["Language"], Iterable[Example]]:
|
||||||
return Corpus(path, gold_preproc=gold_preproc, max_length=max_length, limit=limit)
|
return Corpus(
|
||||||
|
path,
|
||||||
|
gold_preproc=gold_preproc,
|
||||||
|
max_length=max_length,
|
||||||
|
limit=limit,
|
||||||
|
augmenter=augmenter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@util.registry.readers("spacy.JsonlReader.v1")
|
@util.registry.readers("spacy.JsonlReader.v1")
|
||||||
|
@ -70,6 +82,8 @@ class Corpus:
|
||||||
0, which indicates no limit.
|
0, which indicates no limit.
|
||||||
limit (int): Limit corpus to a subset of examples, e.g. for debugging.
|
limit (int): Limit corpus to a subset of examples, e.g. for debugging.
|
||||||
Defaults to 0, which indicates no limit.
|
Defaults to 0, which indicates no limit.
|
||||||
|
augmenter (Callable[[Language, Example], Iterable[Example]]): Optional data augmentation
|
||||||
|
function, to extrapolate additional examples from your annotations.
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/corpus
|
DOCS: https://nightly.spacy.io/api/corpus
|
||||||
"""
|
"""
|
||||||
|
@ -81,11 +95,13 @@ class Corpus:
|
||||||
limit: int = 0,
|
limit: int = 0,
|
||||||
gold_preproc: bool = False,
|
gold_preproc: bool = False,
|
||||||
max_length: int = 0,
|
max_length: int = 0,
|
||||||
|
augmenter: Optional[Callable] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.path = util.ensure_path(path)
|
self.path = util.ensure_path(path)
|
||||||
self.gold_preproc = gold_preproc
|
self.gold_preproc = gold_preproc
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.limit = limit
|
self.limit = limit
|
||||||
|
self.augmenter = augmenter if augmenter is not None else dont_augment
|
||||||
|
|
||||||
def __call__(self, nlp: "Language") -> Iterator[Example]:
|
def __call__(self, nlp: "Language") -> Iterator[Example]:
|
||||||
"""Yield examples from the data.
|
"""Yield examples from the data.
|
||||||
|
@ -100,7 +116,9 @@ class Corpus:
|
||||||
examples = self.make_examples_gold_preproc(nlp, ref_docs)
|
examples = self.make_examples_gold_preproc(nlp, ref_docs)
|
||||||
else:
|
else:
|
||||||
examples = self.make_examples(nlp, ref_docs)
|
examples = self.make_examples(nlp, ref_docs)
|
||||||
yield from examples
|
for real_eg in examples:
|
||||||
|
for augmented_eg in self.augmenter(nlp, real_eg):
|
||||||
|
yield augmented_eg
|
||||||
|
|
||||||
def _make_example(
|
def _make_example(
|
||||||
self, nlp: "Language", reference: Doc, gold_preproc: bool
|
self, nlp: "Language", reference: Doc, gold_preproc: bool
|
||||||
|
|
|
@ -81,6 +81,7 @@ class registry(thinc.registry):
|
||||||
callbacks = catalogue.create("spacy", "callbacks")
|
callbacks = catalogue.create("spacy", "callbacks")
|
||||||
batchers = catalogue.create("spacy", "batchers", entry_points=True)
|
batchers = catalogue.create("spacy", "batchers", entry_points=True)
|
||||||
readers = catalogue.create("spacy", "readers", entry_points=True)
|
readers = catalogue.create("spacy", "readers", entry_points=True)
|
||||||
|
augmenters = catalogue.create("spacy", "augmenters", entry_points=True)
|
||||||
loggers = catalogue.create("spacy", "loggers", entry_points=True)
|
loggers = catalogue.create("spacy", "loggers", entry_points=True)
|
||||||
# These are factories registered via third-party packages and the
|
# These are factories registered via third-party packages and the
|
||||||
# spacy_factories entry point. This registry only exists so we can easily
|
# spacy_factories entry point. This registry only exists so we can easily
|
||||||
|
|
|
@ -74,6 +74,7 @@ train/test skew.
|
||||||
| `gold_preproc` | Whether to set up the Example object with gold-standard sentences and tokens for the predictions. Defaults to `False`. ~~bool~~ |
|
| `gold_preproc` | Whether to set up the Example object with gold-standard sentences and tokens for the predictions. Defaults to `False`. ~~bool~~ |
|
||||||
| `max_length` | Maximum document length. Longer documents will be split into sentences, if sentence boundaries are available. Defaults to `0` for no limit. ~~int~~ |
|
| `max_length` | Maximum document length. Longer documents will be split into sentences, if sentence boundaries are available. Defaults to `0` for no limit. ~~int~~ |
|
||||||
| `limit` | Limit corpus to a subset of examples, e.g. for debugging. Defaults to `0` for no limit. ~~int~~ |
|
| `limit` | Limit corpus to a subset of examples, e.g. for debugging. Defaults to `0` for no limit. ~~int~~ |
|
||||||
|
| `augmenter` | Optional data augmentation callback. ~~Callable[[Language, Example], Iterable[Example]]~~ |
|
||||||
|
|
||||||
## Corpus.\_\_call\_\_ {#call tag="method"}
|
## Corpus.\_\_call\_\_ {#call tag="method"}
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ menu:
|
||||||
- ['Introduction', 'basics']
|
- ['Introduction', 'basics']
|
||||||
- ['Quickstart', 'quickstart']
|
- ['Quickstart', 'quickstart']
|
||||||
- ['Config System', 'config']
|
- ['Config System', 'config']
|
||||||
|
# - ['Data Utilities', 'data']
|
||||||
- ['Custom Functions', 'custom-functions']
|
- ['Custom Functions', 'custom-functions']
|
||||||
- ['Parallel Training', 'parallel-training']
|
- ['Parallel Training', 'parallel-training']
|
||||||
- ['Internal API', 'api']
|
- ['Internal API', 'api']
|
||||||
|
@ -505,6 +506,16 @@ still look good.
|
||||||
|
|
||||||
</Accordion>
|
</Accordion>
|
||||||
|
|
||||||
|
<!--
|
||||||
|
## Data Utilities {#data-utilities}
|
||||||
|
|
||||||
|
* spacy convert
|
||||||
|
* The [corpora] block
|
||||||
|
* Custom corpus class
|
||||||
|
* Minibatching
|
||||||
|
* Data augmentation
|
||||||
|
-->
|
||||||
|
|
||||||
## Custom Functions {#custom-functions}
|
## Custom Functions {#custom-functions}
|
||||||
|
|
||||||
Registered functions in the training config files can refer to built-in
|
Registered functions in the training config files can refer to built-in
|
||||||
|
|
Loading…
Reference in New Issue
Block a user