Fix lowercase augmentation (#7336)

* Fix aborted/skipped augmentation for `spacy.orth_variants.v1` if
lowercasing was enabled for an example
* Simplify `spacy.orth_variants.v1` for `Example` vs. `GoldParse`
* Preserve reference tokenization in `spacy.lower_case.v1`
This commit is contained in:
Adriane Boyd 2021-03-09 04:02:32 +01:00 committed by GitHub
parent cd70c3cb79
commit 3f3e8110dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 114 additions and 105 deletions

View File

@ -38,19 +38,59 @@ def doc(nlp):
@pytest.mark.filterwarnings("ignore::UserWarning") @pytest.mark.filterwarnings("ignore::UserWarning")
def test_make_orth_variants(nlp, doc): def test_make_orth_variants(nlp):
single = [ single = [
{"tags": ["NFP"], "variants": ["", "..."]}, {"tags": ["NFP"], "variants": ["", "..."]},
{"tags": [":"], "variants": ["-", "", "", "--", "---", "——"]}, {"tags": [":"], "variants": ["-", "", "", "--", "---", "——"]},
] ]
# fmt: off
words = ["\n\n", "A", "\t", "B", "a", "b", "", "...", "-", "", "", "--", "---", "——"]
tags = ["_SP", "NN", "\t", "NN", "NN", "NN", "NFP", "NFP", ":", ":", ":", ":", ":", ":"]
# fmt: on
spaces = [True] * len(words)
spaces[0] = False
spaces[2] = False
doc = Doc(nlp.vocab, words=words, spaces=spaces, tags=tags)
augmenter = create_orth_variants_augmenter( augmenter = create_orth_variants_augmenter(
level=0.2, lower=0.5, orth_variants={"single": single} level=0.2, lower=0.5, orth_variants={"single": single}
) )
with make_docbin([doc]) as output_file: with make_docbin([doc] * 10) as output_file:
reader = Corpus(output_file, augmenter=augmenter) reader = Corpus(output_file, augmenter=augmenter)
# Due to randomness, only test that it works without errors for now # Due to randomness, only test that it works without errors
list(reader(nlp)) list(reader(nlp))
# check that the following settings lowercase everything
augmenter = create_orth_variants_augmenter(
level=1.0, lower=1.0, orth_variants={"single": single}
)
with make_docbin([doc] * 10) as output_file:
reader = Corpus(output_file, augmenter=augmenter)
for example in reader(nlp):
for token in example.reference:
assert token.text == token.text.lower()
# check that lowercasing is applied without tags
doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words))
augmenter = create_orth_variants_augmenter(
level=1.0, lower=1.0, orth_variants={"single": single}
)
with make_docbin([doc] * 10) as output_file:
reader = Corpus(output_file, augmenter=augmenter)
for example in reader(nlp):
for ex_token, doc_token in zip(example.reference, doc):
assert ex_token.text == doc_token.text.lower()
# check that no lowercasing is applied with lower=0.0
doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words))
augmenter = create_orth_variants_augmenter(
level=1.0, lower=0.0, orth_variants={"single": single}
)
with make_docbin([doc] * 10) as output_file:
reader = Corpus(output_file, augmenter=augmenter)
for example in reader(nlp):
for ex_token, doc_token in zip(example.reference, doc):
assert ex_token.text == doc_token.text
def test_lowercase_augmenter(nlp, doc): def test_lowercase_augmenter(nlp, doc):
augmenter = create_lower_casing_augmenter(level=1.0) augmenter = create_lower_casing_augmenter(level=1.0)
@ -66,6 +106,21 @@ def test_lowercase_augmenter(nlp, doc):
assert ref_ent.text == orig_ent.text.lower() assert ref_ent.text == orig_ent.text.lower()
assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc] assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc]
# check that augmentation works when lowercasing leads to different
# predicted tokenization
words = ["A", "B", "CCC."]
doc = Doc(nlp.vocab, words=words)
with make_docbin([doc]) as output_file:
reader = Corpus(output_file, augmenter=augmenter)
corpus = list(reader(nlp))
eg = corpus[0]
assert eg.reference.text == doc.text.lower()
assert eg.predicted.text == doc.text.lower()
assert [t.text for t in eg.reference] == [t.lower() for t in words]
assert [t.text for t in eg.predicted] == [
t.text for t in nlp.make_doc(doc.text.lower())
]
@pytest.mark.filterwarnings("ignore::UserWarning") @pytest.mark.filterwarnings("ignore::UserWarning")
def test_custom_data_augmentation(nlp, doc): def test_custom_data_augmentation(nlp, doc):

View File

@ -1,12 +1,10 @@
from typing import Callable, Iterator, Dict, List, Tuple, TYPE_CHECKING from typing import Callable, Iterator, Dict, List, Tuple, TYPE_CHECKING
import random import random
import itertools import itertools
import copy
from functools import partial from functools import partial
from pydantic import BaseModel, StrictStr from pydantic import BaseModel, StrictStr
from ..util import registry from ..util import registry
from ..tokens import Doc
from .example import Example from .example import Example
if TYPE_CHECKING: if TYPE_CHECKING:
@ -71,7 +69,7 @@ def lower_casing_augmenter(
else: else:
example_dict = example.to_dict() example_dict = example.to_dict()
doc = nlp.make_doc(example.text.lower()) doc = nlp.make_doc(example.text.lower())
example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in doc] example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference]
yield example.from_dict(doc, example_dict) yield example.from_dict(doc, example_dict)
@ -88,24 +86,15 @@ def orth_variants_augmenter(
else: else:
raw_text = example.text raw_text = example.text
orig_dict = example.to_dict() orig_dict = example.to_dict()
if not orig_dict["token_annotation"]: variant_text, variant_token_annot = make_orth_variants(
yield example nlp,
else: raw_text,
variant_text, variant_token_annot = make_orth_variants( orig_dict["token_annotation"],
nlp, orth_variants,
raw_text, lower=raw_text is not None and random.random() < lower,
orig_dict["token_annotation"], )
orth_variants, orig_dict["token_annotation"] = variant_token_annot
lower=raw_text is not None and random.random() < lower, yield example.from_dict(nlp.make_doc(variant_text), orig_dict)
)
if variant_text:
doc = nlp.make_doc(variant_text)
else:
doc = Doc(nlp.vocab, words=variant_token_annot["ORTH"])
variant_token_annot["ORTH"] = [w.text for w in doc]
variant_token_annot["SPACY"] = [w.whitespace_ for w in doc]
orig_dict["token_annotation"] = variant_token_annot
yield example.from_dict(doc, orig_dict)
def make_orth_variants( def make_orth_variants(
@ -116,88 +105,53 @@ def make_orth_variants(
*, *,
lower: bool = False, lower: bool = False,
) -> Tuple[str, Dict[str, List[str]]]: ) -> Tuple[str, Dict[str, List[str]]]:
orig_token_dict = copy.deepcopy(token_dict)
ndsv = orth_variants.get("single", [])
ndpv = orth_variants.get("paired", [])
words = token_dict.get("ORTH", []) words = token_dict.get("ORTH", [])
tags = token_dict.get("TAG", []) tags = token_dict.get("TAG", [])
# keep unmodified if words or tags are not defined # keep unmodified if words are not defined
if words and tags: if not words:
if lower: return raw, token_dict
words = [w.lower() for w in words] if lower:
# single variants words = [w.lower() for w in words]
punct_choices = [random.choice(x["variants"]) for x in ndsv] raw = raw.lower()
for word_idx in range(len(words)): # if no tags, only lowercase
for punct_idx in range(len(ndsv)): if not tags:
if (
tags[word_idx] in ndsv[punct_idx]["tags"]
and words[word_idx] in ndsv[punct_idx]["variants"]
):
words[word_idx] = punct_choices[punct_idx]
# paired variants
punct_choices = [random.choice(x["variants"]) for x in ndpv]
for word_idx in range(len(words)):
for punct_idx in range(len(ndpv)):
if tags[word_idx] in ndpv[punct_idx]["tags"] and words[
word_idx
] in itertools.chain.from_iterable(ndpv[punct_idx]["variants"]):
# backup option: random left vs. right from pair
pair_idx = random.choice([0, 1])
# best option: rely on paired POS tags like `` / ''
if len(ndpv[punct_idx]["tags"]) == 2:
pair_idx = ndpv[punct_idx]["tags"].index(tags[word_idx])
# next best option: rely on position in variants
# (may not be unambiguous, so order of variants matters)
else:
for pair in ndpv[punct_idx]["variants"]:
if words[word_idx] in pair:
pair_idx = pair.index(words[word_idx])
words[word_idx] = punct_choices[punct_idx][pair_idx]
token_dict["ORTH"] = words token_dict["ORTH"] = words
token_dict["TAG"] = tags return raw, token_dict
# modify raw # single variants
if raw is not None: ndsv = orth_variants.get("single", [])
variants = [] punct_choices = [random.choice(x["variants"]) for x in ndsv]
for single_variants in ndsv: for word_idx in range(len(words)):
variants.extend(single_variants["variants"]) for punct_idx in range(len(ndsv)):
for paired_variants in ndpv: if (
variants.extend( tags[word_idx] in ndsv[punct_idx]["tags"]
list(itertools.chain.from_iterable(paired_variants["variants"])) and words[word_idx] in ndsv[punct_idx]["variants"]
) ):
# store variants in reverse length order to be able to prioritize words[word_idx] = punct_choices[punct_idx]
# longer matches (e.g., "---" before "--") # paired variants
variants = sorted(variants, key=lambda x: len(x)) ndpv = orth_variants.get("paired", [])
variants.reverse() punct_choices = [random.choice(x["variants"]) for x in ndpv]
variant_raw = "" for word_idx in range(len(words)):
raw_idx = 0 for punct_idx in range(len(ndpv)):
# add initial whitespace if tags[word_idx] in ndpv[punct_idx]["tags"] and words[
while raw_idx < len(raw) and raw[raw_idx].isspace(): word_idx
variant_raw += raw[raw_idx] ] in itertools.chain.from_iterable(ndpv[punct_idx]["variants"]):
raw_idx += 1 # backup option: random left vs. right from pair
for word in words: pair_idx = random.choice([0, 1])
match_found = False # best option: rely on paired POS tags like `` / ''
# skip whitespace words if len(ndpv[punct_idx]["tags"]) == 2:
if word.isspace(): pair_idx = ndpv[punct_idx]["tags"].index(tags[word_idx])
match_found = True # next best option: rely on position in variants
# add identical word # (may not be unambiguous, so order of variants matters)
elif word not in variants and raw[raw_idx:].startswith(word): else:
variant_raw += word for pair in ndpv[punct_idx]["variants"]:
raw_idx += len(word) if words[word_idx] in pair:
match_found = True pair_idx = pair.index(words[word_idx])
# add variant word words[word_idx] = punct_choices[punct_idx][pair_idx]
else: token_dict["ORTH"] = words
for variant in variants: # construct modified raw text from words and spaces
if not match_found and raw[raw_idx:].startswith(variant): raw = ""
raw_idx += len(variant) for orth, spacy in zip(token_dict["ORTH"], token_dict["SPACY"]):
variant_raw += word raw += orth
match_found = True if spacy:
# something went wrong, abort raw += " "
# (add a warning message?)
if not match_found:
return raw, orig_token_dict
# add following whitespace
while raw_idx < len(raw) and raw[raw_idx].isspace():
variant_raw += raw[raw_idx]
raw_idx += 1
raw = variant_raw
return raw, token_dict return raw, token_dict