Fix lowercase augmentation (#7336)

* Fix aborted/skipped augmentation for `spacy.orth_variants.v1` if
lowercasing was enabled for an example
* Simplify `spacy.orth_variants.v1` for `Example` vs. `GoldParse`
* Preserve reference tokenization in `spacy.lower_case.v1`
Adriane Boyd 2021-03-09 04:02:32 +01:00 committed by GitHub
parent cd70c3cb79
commit 3f3e8110dc
2 changed files with 114 additions and 105 deletions
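
For context, this is how the augmenters touched by this commit plug into data loading. A minimal sketch using the public spaCy v3 API; the path "train.spacy", the parameter values, and the tiny orth_variants table are placeholders:

import spacy
from spacy.training import Corpus
from spacy.training.augment import create_orth_variants_augmenter

nlp = spacy.blank("en")
# level: fraction of examples that get augmented at all;
# lower: of those, the chance that an example is additionally lowercased
augmenter = create_orth_variants_augmenter(
    level=0.1,
    lower=0.5,
    orth_variants={"single": [{"tags": ["NFP"], "variants": ["…", "..."]}]},
)
corpus = Corpus("train.spacy", augmenter=augmenter)
examples = list(corpus(nlp))  # pre-fix, examples selected for lowercasing could come back unaugmented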

spacy/tests/training/test_augmenters.py

@@ -38,19 +38,59 @@ def doc(nlp):
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_make_orth_variants(nlp, doc):
def test_make_orth_variants(nlp):
single = [
{"tags": ["NFP"], "variants": ["", "..."]},
{"tags": [":"], "variants": ["-", "", "", "--", "---", "——"]},
]
# fmt: off
words = ["\n\n", "A", "\t", "B", "a", "b", "", "...", "-", "", "", "--", "---", "——"]
tags = ["_SP", "NN", "\t", "NN", "NN", "NN", "NFP", "NFP", ":", ":", ":", ":", ":", ":"]
# fmt: on
spaces = [True] * len(words)
spaces[0] = False
spaces[2] = False
doc = Doc(nlp.vocab, words=words, spaces=spaces, tags=tags)
augmenter = create_orth_variants_augmenter(
level=0.2, lower=0.5, orth_variants={"single": single}
)
with make_docbin([doc]) as output_file:
with make_docbin([doc] * 10) as output_file:
reader = Corpus(output_file, augmenter=augmenter)
# Due to randomness, only test that it works without errors for now
# Due to randomness, only test that it works without errors
list(reader(nlp))
# check that the following settings lowercase everything
augmenter = create_orth_variants_augmenter(
level=1.0, lower=1.0, orth_variants={"single": single}
)
with make_docbin([doc] * 10) as output_file:
reader = Corpus(output_file, augmenter=augmenter)
for example in reader(nlp):
for token in example.reference:
assert token.text == token.text.lower()
# check that lowercasing is applied without tags
doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words))
augmenter = create_orth_variants_augmenter(
level=1.0, lower=1.0, orth_variants={"single": single}
)
with make_docbin([doc] * 10) as output_file:
reader = Corpus(output_file, augmenter=augmenter)
for example in reader(nlp):
for ex_token, doc_token in zip(example.reference, doc):
assert ex_token.text == doc_token.text.lower()
# check that no lowercasing is applied with lower=0.0
doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words))
augmenter = create_orth_variants_augmenter(
level=1.0, lower=0.0, orth_variants={"single": single}
)
with make_docbin([doc] * 10) as output_file:
reader = Corpus(output_file, augmenter=augmenter)
for example in reader(nlp):
for ex_token, doc_token in zip(example.reference, doc):
assert ex_token.text == doc_token.text
def test_lowercase_augmenter(nlp, doc):
augmenter = create_lower_casing_augmenter(level=1.0)
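
The make_docbin helper used throughout these tests is defined earlier in this module; roughly, it writes the docs to a temporary .spacy file that Corpus can read back. A sketch under that assumption, not the exact helper:

import tempfile
from contextlib import contextmanager
from pathlib import Path
from spacy.tokens import DocBin

@contextmanager
def make_docbin(docs, name="roundtrip.spacy"):
    # serialize the docs to a temporary DocBin file and yield its path
    with tempfile.TemporaryDirectory() as tmpdir:
        output_file = Path(tmpdir) / name
        DocBin(docs=docs).to_disk(output_file)
        yield output_file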
@@ -66,6 +106,21 @@ def test_lowercase_augmenter(nlp, doc):
assert ref_ent.text == orig_ent.text.lower()
assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc]
# check that augmentation works when lowercasing leads to different
# predicted tokenization
words = ["A", "B", "CCC."]
doc = Doc(nlp.vocab, words=words)
with make_docbin([doc]) as output_file:
reader = Corpus(output_file, augmenter=augmenter)
corpus = list(reader(nlp))
eg = corpus[0]
assert eg.reference.text == doc.text.lower()
assert eg.predicted.text == doc.text.lower()
assert [t.text for t in eg.reference] == [t.lower() for t in words]
assert [t.text for t in eg.predicted] == [
t.text for t in nlp.make_doc(doc.text.lower())
]
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_custom_data_augmentation(nlp, doc):
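
The assertions above hinge on an Example holding two docs: a reference doc built from the gold tokens and a predicted doc retokenized from raw text, which may split words differently. A small illustration (the exact predicted tokens depend on the tokenizer):

import spacy
from spacy.training import Example

nlp = spacy.blank("en")
words = ["A", "B", "CCC."]  # gold tokens keep "CCC." as a single token
predicted = nlp.make_doc("a b ccc.")  # the tokenizer may split off the period
example = Example.from_dict(predicted, {"words": [w.lower() for w in words]})
print([t.text for t in example.reference])  # ['a', 'b', 'ccc.']
print([t.text for t in example.predicted])  # e.g. ['a', 'b', 'ccc', '.']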

spacy/training/augment.py

@@ -1,12 +1,10 @@
from typing import Callable, Iterator, Dict, List, Tuple, TYPE_CHECKING
import random
import itertools
import copy
from functools import partial
from pydantic import BaseModel, StrictStr
from ..util import registry
from ..tokens import Doc
from .example import Example
if TYPE_CHECKING:
@@ -71,7 +69,7 @@ def lower_casing_augmenter(
else:
example_dict = example.to_dict()
doc = nlp.make_doc(example.text.lower())
example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in doc]
example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference]
yield example.from_dict(doc, example_dict)
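
Pieced together with the unchanged context lines above, the fixed function reads roughly as below; the enclosing level check is assumed from surrounding code rather than shown in this hunk. The fix itself: the reference-side ORTH now comes from the gold tokens (example.reference) instead of the retokenized doc, so gold token boundaries survive even when lowercasing would tokenize differently.

import random
from typing import Iterator
from spacy.training import Example

def lower_casing_augmenter(nlp, example: Example, *, level: float) -> Iterator[Example]:
    if random.random() >= level:
        yield example  # leave this example unaugmented
    else:
        example_dict = example.to_dict()
        # predicted side: retokenize the lowercased raw text
        doc = nlp.make_doc(example.text.lower())
        # reference side: lowercase the gold tokens, preserving their boundaries
        example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference]
        yield example.from_dict(doc, example_dict)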
@@ -88,24 +86,15 @@
else:
raw_text = example.text
orig_dict = example.to_dict()
if not orig_dict["token_annotation"]:
yield example
else:
variant_text, variant_token_annot = make_orth_variants(
nlp,
raw_text,
orig_dict["token_annotation"],
orth_variants,
lower=raw_text is not None and random.random() < lower,
)
if variant_text:
doc = nlp.make_doc(variant_text)
else:
doc = Doc(nlp.vocab, words=variant_token_annot["ORTH"])
variant_token_annot["ORTH"] = [w.text for w in doc]
variant_token_annot["SPACY"] = [w.whitespace_ for w in doc]
orig_dict["token_annotation"] = variant_token_annot
yield example.from_dict(doc, orig_dict)
variant_text, variant_token_annot = make_orth_variants(
nlp,
raw_text,
orig_dict["token_annotation"],
orth_variants,
lower=raw_text is not None and random.random() < lower,
)
orig_dict["token_annotation"] = variant_token_annot
yield example.from_dict(nlp.make_doc(variant_text), orig_dict)
def make_orth_variants(
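
make_orth_variants is a module-level helper in the same file and can be called directly; a small sketch with illustrative inputs (the orth_variants table and token values are made up):

import spacy
from spacy.training.augment import make_orth_variants

nlp = spacy.blank("en")
token_dict = {
    "ORTH": ["A", "sentence", "..."],
    "SPACY": [True, True, False],
    "TAG": ["DT", "NN", "NFP"],
}
variant_text, variant_annot = make_orth_variants(
    nlp,
    "A sentence...",
    token_dict,
    {"single": [{"tags": ["NFP"], "variants": ["…", "..."]}]},
    lower=True,
)
print(variant_text)  # e.g. "a sentence…" (lowercased; "..." may swap to "…")
print(variant_annot["ORTH"])  # gold tokens with the same substitutions applied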
@@ -116,88 +105,53 @@
*,
lower: bool = False,
) -> Tuple[str, Dict[str, List[str]]]:
orig_token_dict = copy.deepcopy(token_dict)
ndsv = orth_variants.get("single", [])
ndpv = orth_variants.get("paired", [])
words = token_dict.get("ORTH", [])
tags = token_dict.get("TAG", [])
# keep unmodified if words or tags are not defined
if words and tags:
if lower:
words = [w.lower() for w in words]
# single variants
punct_choices = [random.choice(x["variants"]) for x in ndsv]
for word_idx in range(len(words)):
for punct_idx in range(len(ndsv)):
if (
tags[word_idx] in ndsv[punct_idx]["tags"]
and words[word_idx] in ndsv[punct_idx]["variants"]
):
words[word_idx] = punct_choices[punct_idx]
# paired variants
punct_choices = [random.choice(x["variants"]) for x in ndpv]
for word_idx in range(len(words)):
for punct_idx in range(len(ndpv)):
if tags[word_idx] in ndpv[punct_idx]["tags"] and words[
word_idx
] in itertools.chain.from_iterable(ndpv[punct_idx]["variants"]):
# backup option: random left vs. right from pair
pair_idx = random.choice([0, 1])
# best option: rely on paired POS tags like `` / ''
if len(ndpv[punct_idx]["tags"]) == 2:
pair_idx = ndpv[punct_idx]["tags"].index(tags[word_idx])
# next best option: rely on position in variants
# (may not be unambiguous, so order of variants matters)
else:
for pair in ndpv[punct_idx]["variants"]:
if words[word_idx] in pair:
pair_idx = pair.index(words[word_idx])
words[word_idx] = punct_choices[punct_idx][pair_idx]
# keep unmodified if words are not defined
if not words:
return raw, token_dict
if lower:
words = [w.lower() for w in words]
raw = raw.lower()
# if no tags, only lowercase
if not tags:
token_dict["ORTH"] = words
token_dict["TAG"] = tags
return raw, token_dict
# modify raw
if raw is not None:
variants = []
for single_variants in ndsv:
variants.extend(single_variants["variants"])
for paired_variants in ndpv:
variants.extend(
list(itertools.chain.from_iterable(paired_variants["variants"]))
)
# store variants in reverse length order to be able to prioritize
# longer matches (e.g., "---" before "--")
variants = sorted(variants, key=lambda x: len(x))
variants.reverse()
variant_raw = ""
raw_idx = 0
# add initial whitespace
while raw_idx < len(raw) and raw[raw_idx].isspace():
variant_raw += raw[raw_idx]
raw_idx += 1
for word in words:
match_found = False
# skip whitespace words
if word.isspace():
match_found = True
# add identical word
elif word not in variants and raw[raw_idx:].startswith(word):
variant_raw += word
raw_idx += len(word)
match_found = True
# add variant word
else:
for variant in variants:
if not match_found and raw[raw_idx:].startswith(variant):
raw_idx += len(variant)
variant_raw += word
match_found = True
# something went wrong, abort
# (add a warning message?)
if not match_found:
return raw, orig_token_dict
# add following whitespace
while raw_idx < len(raw) and raw[raw_idx].isspace():
variant_raw += raw[raw_idx]
raw_idx += 1
raw = variant_raw
return raw, token_dict
# single variants
ndsv = orth_variants.get("single", [])
punct_choices = [random.choice(x["variants"]) for x in ndsv]
for word_idx in range(len(words)):
for punct_idx in range(len(ndsv)):
if (
tags[word_idx] in ndsv[punct_idx]["tags"]
and words[word_idx] in ndsv[punct_idx]["variants"]
):
words[word_idx] = punct_choices[punct_idx]
# paired variants
ndpv = orth_variants.get("paired", [])
punct_choices = [random.choice(x["variants"]) for x in ndpv]
for word_idx in range(len(words)):
for punct_idx in range(len(ndpv)):
if tags[word_idx] in ndpv[punct_idx]["tags"] and words[
word_idx
] in itertools.chain.from_iterable(ndpv[punct_idx]["variants"]):
# backup option: random left vs. right from pair
pair_idx = random.choice([0, 1])
# best option: rely on paired POS tags like `` / ''
if len(ndpv[punct_idx]["tags"]) == 2:
pair_idx = ndpv[punct_idx]["tags"].index(tags[word_idx])
# next best option: rely on position in variants
# (may not be unambiguous, so order of variants matters)
else:
for pair in ndpv[punct_idx]["variants"]:
if words[word_idx] in pair:
pair_idx = pair.index(words[word_idx])
words[word_idx] = punct_choices[punct_idx][pair_idx]
token_dict["ORTH"] = words
# construct modified raw text from words and spaces
raw = ""
for orth, spacy in zip(token_dict["ORTH"], token_dict["SPACY"]):
raw += orth
if spacy:
raw += " "
return raw, token_dict
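
A worked sketch of the paired-variant index choice in the loop above, standalone and with illustrative values: when a group lists exactly two tags, such as the opening/closing quote tags `` and '', the token's own tag identifies which side of the chosen pair to substitute.

import random

ndpv = [{"tags": ["``", "''"], "variants": [["``", "''"], ['"', '"']]}]
tags = ["``", "NN", "''"]
words = ["``", "word", "''"]

# pick one replacement pair per group, as in the code above
punct_choices = [random.choice(x["variants"]) for x in ndpv]
for word_idx in range(len(words)):
    for punct_idx in range(len(ndpv)):
        group = ndpv[punct_idx]
        flat = [w for pair in group["variants"] for w in pair]
        if tags[word_idx] in group["tags"] and words[word_idx] in flat:
            # "best option" from the comments above: paired tags give the index
            pair_idx = group["tags"].index(tags[word_idx])
            words[word_idx] = punct_choices[punct_idx][pair_idx]
print(words)  # e.g. ['"', 'word', '"'] or ['``', 'word', "''"]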