mirror of
synced 2025-03-15 23:52:30 +03:00
Add whitespace and combined augmenters (#10170)
Add whitespace augmenter that inserts a single whitespace token into a doc containing annotation used in core trained pipelines. Add a combined augmenter that handles lowercasing, orth variants and whitespace augmentation.
This commit is contained in:
@ -1,9 +1,11 @@
import pytest
from spacy.training import Corpus
from spacy.pipeline._parser_internals.nonproj import contains_cycle
from spacy.training import Corpus, Example
from spacy.training.augment import create_orth_variants_augmenter
from spacy.training.augment import create_lower_casing_augmenter
from spacy.training.augment import make_whitespace_variant
from spacy.lang.en import English
from spacy.tokens import DocBin, Doc
from spacy.tokens import DocBin, Doc, Span
from contextlib import contextmanager
import random
@ -153,3 +155,84 @@ def test_custom_data_augmentation(nlp, doc):
ents = [(e.start, e.end, e.label) for e in doc.ents]
assert [(e.start, e.end, e.label) for e in corpus[0].reference.ents] == ents
assert [(e.start, e.end, e.label) for e in corpus[1].reference.ents] == ents
def test_make_whitespace_variant(nlp):
# fmt: off
text = "They flew to New York City.\nThen they drove to Washington, D.C."
words = ["They", "flew", "to", "New", "York", "City", ".", "\n", "Then", "they", "drove", "to", "Washington", ",", "D.C."]
spaces = [True, True, True, True, True, False, False, False, True, True, True, True, False, True, False]
tags = ["PRP", "VBD", "IN", "NNP", "NNP", "NNP", ".", "_SP", "RB", "PRP", "VBD", "IN", "NNP", ",", "NNP"]
lemmas = ["they", "fly", "to", "New", "York", "City", ".", "\n", "then", "they", "drive", "to", "Washington", ",", "D.C."]
heads = [1, 1, 1, 4, 5, 2, 1, 10, 10, 10, 10, 10, 11, 12, 12]
deps = ["nsubj", "ROOT", "prep", "compound", "compound", "pobj", "punct", "dep", "advmod", "nsubj", "ROOT", "prep", "pobj", "punct", "appos"]
ents = ["O", "O", "O", "B-GPE", "I-GPE", "I-GPE", "O", "O", "O", "O", "O", "O", "B-GPE", "O", "B-GPE"]
# fmt: on
doc = Doc(
assert doc.text == text
example = Example(nlp.make_doc(text), doc)
# whitespace is only added internally in entity spans
mod_ex = make_whitespace_variant(nlp, example, " ", 3)
assert mod_ex.reference.ents[0].text == "New York City"
mod_ex = make_whitespace_variant(nlp, example, " ", 4)
assert mod_ex.reference.ents[0].text == "New York City"
mod_ex = make_whitespace_variant(nlp, example, " ", 5)
assert mod_ex.reference.ents[0].text == "New York City"
mod_ex = make_whitespace_variant(nlp, example, " ", 6)
assert mod_ex.reference.ents[0].text == "New York City"
# add a space at every possible position
for i in range(len(doc) + 1):
mod_ex = make_whitespace_variant(nlp, example, " ", i)
assert mod_ex.reference[i].is_space
# adds annotation when the doc contains at least partial annotation
assert [t.tag_ for t in mod_ex.reference] == tags[:i] + ["_SP"] + tags[i:]
assert [t.lemma_ for t in mod_ex.reference] == lemmas[:i] + [" "] + lemmas[i:]
assert [t.dep_ for t in mod_ex.reference] == deps[:i] + ["dep"] + deps[i:]
# does not add partial annotation if doc does not contain this feature
assert not mod_ex.reference.has_annotation("POS")
assert not mod_ex.reference.has_annotation("MORPH")
# produces well-formed trees
assert not contains_cycle([t.head.i for t in mod_ex.reference])
assert len(list(doc.sents)) == 2
if i == 0:
assert mod_ex.reference[i].head.i == 1
assert mod_ex.reference[i].head.i == i - 1
# adding another space also produces well-formed trees
for j in (3, 8, 10):
mod_ex2 = make_whitespace_variant(nlp, mod_ex, "\t\t\n", j)
assert not contains_cycle([t.head.i for t in mod_ex2.reference])
assert len(list(doc.sents)) == 2
assert mod_ex2.reference[j].head.i == j - 1
# entities are well-formed
assert len(doc.ents) == len(mod_ex.reference.ents)
for ent in mod_ex.reference.ents:
assert not ent[0].is_space
assert not ent[-1].is_space
# no modifications if:
# partial dependencies
example.reference[0].dep_ = ""
mod_ex = make_whitespace_variant(nlp, example, " ", 5)
assert mod_ex.text == example.reference.text
example.reference[0].dep_ = "nsubj" # reset
# spans
example.reference.spans["spans"] = [example.reference[0:5]]
mod_ex = make_whitespace_variant(nlp, example, " ", 5)
assert mod_ex.text == example.reference.text
del example.reference.spans["spans"] # reset
# links
example.reference.ents = [Span(doc, 0, 2, label="ENT", kb_id="Q123")]
mod_ex = make_whitespace_variant(nlp, example, " ", 5)
assert mod_ex.text == example.reference.text
@ -1,4 +1,5 @@
from typing import Callable, Iterator, Dict, List, Tuple, TYPE_CHECKING
from typing import Optional
import random
import itertools
from functools import partial
@ -11,32 +12,87 @@ if TYPE_CHECKING:
from ..language import Language # noqa: F401
class OrthVariantsSingle(BaseModel):
    """Schema for one entry of the "single" orth-variant list."""

    # List of strict strings.
    # NOTE(review): presumably the tags of tokens the variants apply to — confirm against make_orth_variants.
    tags: List[StrictStr]
    # Flat list of strict strings, one per variant.
    variants: List[StrictStr]
def create_combined_augmenter(
lower_level: float,
orth_level: float,
orth_variants: Optional[Dict[str, List[Dict]]],
whitespace_level: float,
whitespace_per_token: float,
whitespace_variants: Optional[List[str]],
) -> Callable[["Language", Example], Iterator[Example]]:
"""Create a data augmentation callback that combines lowercasing, orth-variant replacement and whitespace insertion.
The callback can be added to a corpus or other data iterator during training.
lower_level (float): The percentage of texts that will be lowercased.
orth_level (float): The percentage of texts that will be augmented.
orth_variants (Optional[Dict[str, List[Dict]]]): A dictionary containing the
single and paired orth variants. Typically loaded from a JSON file.
whitespace_level (float): The percentage of texts that will have whitespace
tokens inserted.
whitespace_per_token (float): The number of whitespace tokens to insert in
the modified doc as a percentage of the doc length.
whitespace_variants (Optional[List[str]]): The whitespace token texts.
RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
return partial(
class OrthVariantsPaired(BaseModel):
    """Schema for one entry of the "paired" orth-variant list."""

    # List of strict strings.
    # NOTE(review): presumably the tags of tokens the variants apply to — confirm against make_orth_variants.
    tags: List[StrictStr]
    # List of lists of strict strings (each inner list is one group of variants).
    variants: List[List[StrictStr]]
class OrthVariants(BaseModel):
    """Schema grouping the paired and single orth-variant entries."""

    # Both fields default to an empty list when not provided.
    paired: List[OrthVariantsPaired] = []
    single: List[OrthVariantsSingle] = []
def combined_augmenter(
nlp: "Language",
example: Example,
lower_level: float = 0.0,
orth_level: float = 0.0,
orth_variants: Optional[Dict[str, List[Dict]]] = None,
whitespace_level: float = 0.0,
whitespace_per_token: float = 0.0,
whitespace_variants: Optional[List[str]] = None,
) -> Iterator[Example]:
if random.random() < lower_level:
example = make_lowercase_variant(nlp, example)
if orth_variants and random.random() < orth_level:
raw_text = example.text
orig_dict = example.to_dict()
variant_text, variant_token_annot = make_orth_variants(
orig_dict["token_annotation"] = variant_token_annot
example = example.from_dict(nlp.make_doc(variant_text), orig_dict)
if whitespace_variants and random.random() < whitespace_level:
for _ in range(int(len(example.reference) * whitespace_per_token)):
example = make_whitespace_variant(
random.randrange(0, len(example.reference)),
yield example
def create_orth_variants_augmenter(
level: float, lower: float, orth_variants: OrthVariants
level: float, lower: float, orth_variants: Dict[str, List[Dict]]
) -> Callable[["Language", Example], Iterator[Example]]:
"""Create a data augmentation callback that uses orth-variant replacement.
The callback can be added to a corpus or other data iterator during training.
level (float): The percentage of texts that will be augmented.
lower (float): The percentage of texts that will be lowercased.
orth_variants (Dict[str, dict]): A dictionary containing the single and
paired orth variants. Typically loaded from a JSON file.
orth_variants (Dict[str, List[Dict]]): A dictionary containing
the single and paired orth variants. Typically loaded from a JSON file.
RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
return partial(
@ -67,16 +123,20 @@ def lower_casing_augmenter(
if random.random() >= level:
yield example
example_dict = example.to_dict()
doc = nlp.make_doc(example.text.lower())
example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference]
yield example.from_dict(doc, example_dict)
yield make_lowercase_variant(nlp, example)
def make_lowercase_variant(nlp: "Language", example: Example):
    """Build a lowercased variant of an example.

    nlp (Language): The pipeline whose `make_doc` tokenizes the lowercased text.
    example (Example): The example to lowercase.
    RETURNS: The result of `example.from_dict` on the lowercased doc and the
        example's annotation dict with lowercased ORTH values.
    """
    annotations = example.to_dict()
    lowered_text = example.text.lower()
    # Keep the token-level ORTH annotation in sync with the lowercased text.
    lowered_words = [token.lower_ for token in example.reference]
    annotations["token_annotation"]["ORTH"] = lowered_words
    return example.from_dict(nlp.make_doc(lowered_text), annotations)
def orth_variants_augmenter(
nlp: "Language",
example: Example,
orth_variants: Dict,
orth_variants: Dict[str, List[Dict]],
level: float = 0.0,
lower: float = 0.0,
@ -148,10 +208,132 @@ def make_orth_variants(
pair_idx = pair.index(words[word_idx])
words[word_idx] = punct_choices[punct_idx][pair_idx]
token_dict["ORTH"] = words
# construct modified raw text from words and spaces
raw = construct_modified_raw_text(token_dict)
return raw, token_dict
def make_whitespace_variant(
nlp: "Language",
example: Example,
whitespace: str,
position: int,
) -> Example:
"""Insert the whitespace token at the specified token offset in the doc.
This is primarily intended for v2-compatible training data that doesn't
include links or spans. If the document includes links, spans, or partial
dependency annotation, it is returned without modifications.
The augmentation follows the basics of the v2 space attachment policy, but
without a distinction between "real" and other tokens, so space tokens
may be attached to space tokens:
- at the beginning of a sentence attach the space token to the following token
- otherwise attach the space token to the preceding token
The augmenter does not attempt to consolidate adjacent whitespace in the
same way that the tokenizer would.
The following annotation is used for the space token:
TAG: "_SP"
DEP: "dep"
The annotation for each attribute is only set for the space token if there
is already at least partial annotation for that attribute in the original doc.
RETURNS (Example): Example with one additional space token.
example_dict = example.to_dict()
doc_dict = example_dict.get("doc_annotation", {})
token_dict = example_dict.get("token_annotation", {})
# returned unmodified if:
# - doc is empty
# - words are not defined
# - links are defined (only character-based offsets, which is more a quirk
# of Example.to_dict than a technical constraint)
# - spans are defined
# - there are partial dependencies
if (
len(example.reference) == 0
or "ORTH" not in token_dict
or len(doc_dict.get("links", [])) > 0
or len(example.reference.spans) > 0
or (
and not example.reference.has_annotation("DEP", require_complete=True)
return example
words = token_dict.get("ORTH", [])
length = len(words)
assert 0 <= position <= length
if example.reference.has_annotation("ENT_TYPE"):
# I-ENTITY if between B/I-ENTITY and I/L-ENTITY otherwise O
entity = "O"
if position > 1 and position < length:
ent_prev = doc_dict["entities"][position - 1]
ent_next = doc_dict["entities"][position]
if "-" in ent_prev and "-" in ent_next:
ent_iob_prev = ent_prev.split("-")[0]
ent_type_prev = ent_prev.split("-", 1)[1]
ent_iob_next = ent_next.split("-")[0]
ent_type_next = ent_next.split("-", 1)[1]
if (
ent_iob_prev in ("B", "I")
and ent_iob_next in ("I", "L")
and ent_type_prev == ent_type_next
entity = f"I-{ent_type_prev}"
doc_dict["entities"].insert(position, entity)
del doc_dict["entities"]
token_dict["ORTH"].insert(position, whitespace)
token_dict["SPACY"].insert(position, False)
if example.reference.has_annotation("TAG"):
token_dict["TAG"].insert(position, "_SP")
del token_dict["TAG"]
if example.reference.has_annotation("LEMMA"):
token_dict["LEMMA"].insert(position, whitespace)
del token_dict["LEMMA"]
if example.reference.has_annotation("POS"):
token_dict["POS"].insert(position, "SPACE")
del token_dict["POS"]
if example.reference.has_annotation("MORPH"):
token_dict["MORPH"].insert(position, "")
del token_dict["MORPH"]
if example.reference.has_annotation("DEP", require_complete=True):
if position == 0:
token_dict["HEAD"].insert(position, 0)
token_dict["HEAD"].insert(position, position - 1)
for i in range(len(token_dict["HEAD"])):
if token_dict["HEAD"][i] >= position:
token_dict["HEAD"][i] += 1
token_dict["DEP"].insert(position, "dep")
del token_dict["HEAD"]
del token_dict["DEP"]
if example.reference.has_annotation("SENT_START"):
token_dict["SENT_START"].insert(position, False)
del token_dict["SENT_START"]
raw = construct_modified_raw_text(token_dict)
return Example.from_dict(nlp.make_doc(raw), example_dict)
def construct_modified_raw_text(token_dict):
    """Construct modified raw text from words and spaces.

    token_dict (dict): Token annotation with parallel "ORTH" (token texts) and
        "SPACY" (trailing-space flags) sequences.
    RETURNS (str): The concatenated token texts, with a single space appended
        after every token whose "SPACY" flag is truthy.
    """
    # Return only the text: callers bind the result to a single name and pass
    # it on as a string, so a (raw, token_dict) tuple would be wrong here.
    return "".join(
        orth + (" " if spacy else "")
        for orth, spacy in zip(token_dict["ORTH"], token_dict["SPACY"])
    )
Reference in New Issue
Block a user