Add whitespace and combined augmenters (#10170)

Add whitespace augmenter that inserts a single whitespace token into a
doc containing annotation used in core trained pipelines.

Add a combined augmenter that handles lowercasing, orth variants and
whitespace augmentation.
This commit is contained in:
Adriane Boyd 2022-02-17 15:54:09 +01:00 committed by GitHub
parent aa93b471a1
commit 28ba31e793
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 288 additions and 23 deletions

View File

@ -1,9 +1,11 @@
import pytest import pytest
from spacy.training import Corpus from spacy.pipeline._parser_internals.nonproj import contains_cycle
from spacy.training import Corpus, Example
from spacy.training.augment import create_orth_variants_augmenter from spacy.training.augment import create_orth_variants_augmenter
from spacy.training.augment import create_lower_casing_augmenter from spacy.training.augment import create_lower_casing_augmenter
from spacy.training.augment import make_whitespace_variant
from spacy.lang.en import English from spacy.lang.en import English
from spacy.tokens import DocBin, Doc from spacy.tokens import DocBin, Doc, Span
from contextlib import contextmanager from contextlib import contextmanager
import random import random
@ -153,3 +155,84 @@ def test_custom_data_augmentation(nlp, doc):
ents = [(e.start, e.end, e.label) for e in doc.ents] ents = [(e.start, e.end, e.label) for e in doc.ents]
assert [(e.start, e.end, e.label) for e in corpus[0].reference.ents] == ents assert [(e.start, e.end, e.label) for e in corpus[0].reference.ents] == ents
assert [(e.start, e.end, e.label) for e in corpus[1].reference.ents] == ents assert [(e.start, e.end, e.label) for e in corpus[1].reference.ents] == ents
def test_make_whitespace_variant(nlp):
# fmt: off
text = "They flew to New York City.\nThen they drove to Washington, D.C."
words = ["They", "flew", "to", "New", "York", "City", ".", "\n", "Then", "they", "drove", "to", "Washington", ",", "D.C."]
spaces = [True, True, True, True, True, False, False, False, True, True, True, True, False, True, False]
tags = ["PRP", "VBD", "IN", "NNP", "NNP", "NNP", ".", "_SP", "RB", "PRP", "VBD", "IN", "NNP", ",", "NNP"]
lemmas = ["they", "fly", "to", "New", "York", "City", ".", "\n", "then", "they", "drive", "to", "Washington", ",", "D.C."]
heads = [1, 1, 1, 4, 5, 2, 1, 10, 10, 10, 10, 10, 11, 12, 12]
deps = ["nsubj", "ROOT", "prep", "compound", "compound", "pobj", "punct", "dep", "advmod", "nsubj", "ROOT", "prep", "pobj", "punct", "appos"]
ents = ["O", "O", "O", "B-GPE", "I-GPE", "I-GPE", "O", "O", "O", "O", "O", "O", "B-GPE", "O", "B-GPE"]
# fmt: on
doc = Doc(
nlp.vocab,
words=words,
spaces=spaces,
tags=tags,
lemmas=lemmas,
heads=heads,
deps=deps,
ents=ents,
)
assert doc.text == text
example = Example(nlp.make_doc(text), doc)
# whitespace is only added internally in entity spans
mod_ex = make_whitespace_variant(nlp, example, " ", 3)
assert mod_ex.reference.ents[0].text == "New York City"
mod_ex = make_whitespace_variant(nlp, example, " ", 4)
assert mod_ex.reference.ents[0].text == "New York City"
mod_ex = make_whitespace_variant(nlp, example, " ", 5)
assert mod_ex.reference.ents[0].text == "New York City"
mod_ex = make_whitespace_variant(nlp, example, " ", 6)
assert mod_ex.reference.ents[0].text == "New York City"
# add a space at every possible position
for i in range(len(doc) + 1):
mod_ex = make_whitespace_variant(nlp, example, " ", i)
assert mod_ex.reference[i].is_space
# adds annotation when the doc contains at least partial annotation
assert [t.tag_ for t in mod_ex.reference] == tags[:i] + ["_SP"] + tags[i:]
assert [t.lemma_ for t in mod_ex.reference] == lemmas[:i] + [" "] + lemmas[i:]
assert [t.dep_ for t in mod_ex.reference] == deps[:i] + ["dep"] + deps[i:]
# does not add partial annotation if doc does not contain this feature
assert not mod_ex.reference.has_annotation("POS")
assert not mod_ex.reference.has_annotation("MORPH")
# produces well-formed trees
assert not contains_cycle([t.head.i for t in mod_ex.reference])
assert len(list(doc.sents)) == 2
if i == 0:
assert mod_ex.reference[i].head.i == 1
else:
assert mod_ex.reference[i].head.i == i - 1
# adding another space also produces well-formed trees
for j in (3, 8, 10):
mod_ex2 = make_whitespace_variant(nlp, mod_ex, "\t\t\n", j)
assert not contains_cycle([t.head.i for t in mod_ex2.reference])
assert len(list(doc.sents)) == 2
assert mod_ex2.reference[j].head.i == j - 1
# entities are well-formed
assert len(doc.ents) == len(mod_ex.reference.ents)
for ent in mod_ex.reference.ents:
assert not ent[0].is_space
assert not ent[-1].is_space
# no modifications if:
# partial dependencies
example.reference[0].dep_ = ""
mod_ex = make_whitespace_variant(nlp, example, " ", 5)
assert mod_ex.text == example.reference.text
example.reference[0].dep_ = "nsubj" # reset
# spans
example.reference.spans["spans"] = [example.reference[0:5]]
mod_ex = make_whitespace_variant(nlp, example, " ", 5)
assert mod_ex.text == example.reference.text
del example.reference.spans["spans"] # reset
# links
example.reference.ents = [Span(doc, 0, 2, label="ENT", kb_id="Q123")]
mod_ex = make_whitespace_variant(nlp, example, " ", 5)
assert mod_ex.text == example.reference.text

View File

@ -1,4 +1,5 @@
from typing import Callable, Iterator, Dict, List, Tuple, TYPE_CHECKING from typing import Callable, Iterator, Dict, List, Tuple, TYPE_CHECKING
from typing import Optional
import random import random
import itertools import itertools
from functools import partial from functools import partial
@ -11,32 +12,87 @@ if TYPE_CHECKING:
from ..language import Language # noqa: F401 from ..language import Language # noqa: F401
class OrthVariantsSingle(BaseModel): @registry.augmenters("spacy.combined_augmenter.v1")
tags: List[StrictStr] def create_combined_augmenter(
variants: List[StrictStr] lower_level: float,
orth_level: float,
orth_variants: Optional[Dict[str, List[Dict]]],
whitespace_level: float,
whitespace_per_token: float,
whitespace_variants: Optional[List[str]],
) -> Callable[["Language", Example], Iterator[Example]]:
"""Create a data augmentation callback that uses orth-variant replacement.
The callback can be added to a corpus or other data iterator during training.
lower_level (float): The percentage of texts that will be lowercased.
orth_level (float): The percentage of texts that will be augmented.
orth_variants (Optional[Dict[str, List[Dict]]]): A dictionary containing the
single and paired orth variants. Typically loaded from a JSON file.
whitespace_level (float): The percentage of texts that will have whitespace
tokens inserted.
whitespace_per_token (float): The number of whitespace tokens to insert in
the modified doc as a percentage of the doc length.
whitespace_variants (Optional[List[str]]): The whitespace token texts.
RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
"""
return partial(
combined_augmenter,
lower_level=lower_level,
orth_level=orth_level,
orth_variants=orth_variants,
whitespace_level=whitespace_level,
whitespace_per_token=whitespace_per_token,
whitespace_variants=whitespace_variants,
)
class OrthVariantsPaired(BaseModel): def combined_augmenter(
tags: List[StrictStr] nlp: "Language",
variants: List[List[StrictStr]] example: Example,
*,
lower_level: float = 0.0,
class OrthVariants(BaseModel): orth_level: float = 0.0,
paired: List[OrthVariantsPaired] = [] orth_variants: Optional[Dict[str, List[Dict]]] = None,
single: List[OrthVariantsSingle] = [] whitespace_level: float = 0.0,
whitespace_per_token: float = 0.0,
whitespace_variants: Optional[List[str]] = None,
) -> Iterator[Example]:
if random.random() < lower_level:
example = make_lowercase_variant(nlp, example)
if orth_variants and random.random() < orth_level:
raw_text = example.text
orig_dict = example.to_dict()
variant_text, variant_token_annot = make_orth_variants(
nlp,
raw_text,
orig_dict["token_annotation"],
orth_variants,
lower=False,
)
orig_dict["token_annotation"] = variant_token_annot
example = example.from_dict(nlp.make_doc(variant_text), orig_dict)
if whitespace_variants and random.random() < whitespace_level:
for _ in range(int(len(example.reference) * whitespace_per_token)):
example = make_whitespace_variant(
nlp,
example,
random.choice(whitespace_variants),
random.randrange(0, len(example.reference)),
)
yield example
@registry.augmenters("spacy.orth_variants.v1") @registry.augmenters("spacy.orth_variants.v1")
def create_orth_variants_augmenter( def create_orth_variants_augmenter(
level: float, lower: float, orth_variants: OrthVariants level: float, lower: float, orth_variants: Dict[str, List[Dict]]
) -> Callable[["Language", Example], Iterator[Example]]: ) -> Callable[["Language", Example], Iterator[Example]]:
"""Create a data augmentation callback that uses orth-variant replacement. """Create a data augmentation callback that uses orth-variant replacement.
The callback can be added to a corpus or other data iterator during training. The callback can be added to a corpus or other data iterator during training.
level (float): The percentage of texts that will be augmented. level (float): The percentage of texts that will be augmented.
lower (float): The percentage of texts that will be lowercased. lower (float): The percentage of texts that will be lowercased.
orth_variants (Dict[str, dict]): A dictionary containing the single and orth_variants (Dict[str, List[Dict]]): A dictionary containing
paired orth variants. Typically loaded from a JSON file. the single and paired orth variants. Typically loaded from a JSON file.
RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter. RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
""" """
return partial( return partial(
@ -67,16 +123,20 @@ def lower_casing_augmenter(
if random.random() >= level: if random.random() >= level:
yield example yield example
else: else:
example_dict = example.to_dict() yield make_lowercase_variant(nlp, example)
doc = nlp.make_doc(example.text.lower())
example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference]
yield example.from_dict(doc, example_dict) def make_lowercase_variant(nlp: "Language", example: Example):
example_dict = example.to_dict()
doc = nlp.make_doc(example.text.lower())
example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference]
return example.from_dict(doc, example_dict)
def orth_variants_augmenter( def orth_variants_augmenter(
nlp: "Language", nlp: "Language",
example: Example, example: Example,
orth_variants: Dict, orth_variants: Dict[str, List[Dict]],
*, *,
level: float = 0.0, level: float = 0.0,
lower: float = 0.0, lower: float = 0.0,
@ -148,10 +208,132 @@ def make_orth_variants(
pair_idx = pair.index(words[word_idx]) pair_idx = pair.index(words[word_idx])
words[word_idx] = punct_choices[punct_idx][pair_idx] words[word_idx] = punct_choices[punct_idx][pair_idx]
token_dict["ORTH"] = words token_dict["ORTH"] = words
# construct modified raw text from words and spaces raw = construct_modified_raw_text(token_dict)
return raw, token_dict
def make_whitespace_variant(
nlp: "Language",
example: Example,
whitespace: str,
position: int,
) -> Example:
"""Insert the whitespace token at the specified token offset in the doc.
This is primarily intended for v2-compatible training data that doesn't
include links or spans. If the document includes links, spans, or partial
dependency annotation, it is returned without modifications.
The augmentation follows the basics of the v2 space attachment policy, but
without a distinction between "real" and other tokens, so space tokens
may be attached to space tokens:
- at the beginning of a sentence attach the space token to the following
token
- otherwise attach the space token to the preceding token
The augmenter does not attempt to consolidate adjacent whitespace in the
same way that the tokenizer would.
The following annotation is used for the space token:
TAG: "_SP"
MORPH: ""
POS: "SPACE"
LEMMA: ORTH
DEP: "dep"
SENT_START: False
The annotation for each attribute is only set for the space token if there
is already at least partial annotation for that attribute in the original
example.
RETURNS (Example): Example with one additional space token.
"""
example_dict = example.to_dict()
doc_dict = example_dict.get("doc_annotation", {})
token_dict = example_dict.get("token_annotation", {})
# returned unmodified if:
# - doc is empty
# - words are not defined
# - links are defined (only character-based offsets, which is more a quirk
# of Example.to_dict than a technical constraint)
# - spans are defined
# - there are partial dependencies
if (
len(example.reference) == 0
or "ORTH" not in token_dict
or len(doc_dict.get("links", [])) > 0
or len(example.reference.spans) > 0
or (
example.reference.has_annotation("DEP")
and not example.reference.has_annotation("DEP", require_complete=True)
)
):
return example
words = token_dict.get("ORTH", [])
length = len(words)
assert 0 <= position <= length
if example.reference.has_annotation("ENT_TYPE"):
# I-ENTITY if between B/I-ENTITY and I/L-ENTITY otherwise O
entity = "O"
if position > 1 and position < length:
ent_prev = doc_dict["entities"][position - 1]
ent_next = doc_dict["entities"][position]
if "-" in ent_prev and "-" in ent_next:
ent_iob_prev = ent_prev.split("-")[0]
ent_type_prev = ent_prev.split("-", 1)[1]
ent_iob_next = ent_next.split("-")[0]
ent_type_next = ent_next.split("-", 1)[1]
if (
ent_iob_prev in ("B", "I")
and ent_iob_next in ("I", "L")
and ent_type_prev == ent_type_next
):
entity = f"I-{ent_type_prev}"
doc_dict["entities"].insert(position, entity)
else:
del doc_dict["entities"]
token_dict["ORTH"].insert(position, whitespace)
token_dict["SPACY"].insert(position, False)
if example.reference.has_annotation("TAG"):
token_dict["TAG"].insert(position, "_SP")
else:
del token_dict["TAG"]
if example.reference.has_annotation("LEMMA"):
token_dict["LEMMA"].insert(position, whitespace)
else:
del token_dict["LEMMA"]
if example.reference.has_annotation("POS"):
token_dict["POS"].insert(position, "SPACE")
else:
del token_dict["POS"]
if example.reference.has_annotation("MORPH"):
token_dict["MORPH"].insert(position, "")
else:
del token_dict["MORPH"]
if example.reference.has_annotation("DEP", require_complete=True):
if position == 0:
token_dict["HEAD"].insert(position, 0)
else:
token_dict["HEAD"].insert(position, position - 1)
for i in range(len(token_dict["HEAD"])):
if token_dict["HEAD"][i] >= position:
token_dict["HEAD"][i] += 1
token_dict["DEP"].insert(position, "dep")
else:
del token_dict["HEAD"]
del token_dict["DEP"]
if example.reference.has_annotation("SENT_START"):
token_dict["SENT_START"].insert(position, False)
else:
del token_dict["SENT_START"]
raw = construct_modified_raw_text(token_dict)
return Example.from_dict(nlp.make_doc(raw), example_dict)
def construct_modified_raw_text(token_dict):
"""Construct modified raw text from words and spaces."""
raw = "" raw = ""
for orth, spacy in zip(token_dict["ORTH"], token_dict["SPACY"]): for orth, spacy in zip(token_dict["ORTH"], token_dict["SPACY"]):
raw += orth raw += orth
if spacy: if spacy:
raw += " " raw += " "
return raw, token_dict return raw