Fix lowercase augmentation (#7336)

* Fix aborted/skipped augmentation for `spacy.orth_variants.v1` if
lowercasing was enabled for an example
* Simplify `spacy.orth_variants.v1` for `Example` vs. `GoldParse`
* Preserve reference tokenization in `spacy.lower_case.v1`
Adriane Boyd 2021-03-09 04:02:32 +01:00 committed by GitHub
parent cd70c3cb79
commit 3f3e8110dc
2 changed files with 114 additions and 105 deletions
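
For orientation before the diff: a minimal sketch of how the two augmenters touched by this commit are typically constructed and handed to the corpus reader. The orth-variant data and the `train.spacy` path are illustrative, not part of this commit.

from pathlib import Path

from spacy.lang.en import English
from spacy.training import Corpus
from spacy.training.augment import (
    create_lower_casing_augmenter,
    create_orth_variants_augmenter,
)

nlp = English()
# spacy.orth_variants.v1: swap punctuation variants in 20% of the examples
# and lowercase half of the modified ones; the variant data is illustrative.
orth_augmenter = create_orth_variants_augmenter(
    level=0.2,
    lower=0.5,
    orth_variants={"single": [{"tags": ["NFP"], "variants": ["…", "..."]}]},
)
# spacy.lower_case.v1: lowercase 10% of the examples.
lower_augmenter = create_lower_casing_augmenter(level=0.1)
# Either augmenter can be passed to the Corpus reader; it is applied on the
# fly each time the training corpus is read.
reader = Corpus(Path("train.spacy"), augmenter=orth_augmenter)
examples = list(reader(nlp))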


@@ -38,19 +38,59 @@ def doc(nlp):
 @pytest.mark.filterwarnings("ignore::UserWarning")
-def test_make_orth_variants(nlp, doc):
+def test_make_orth_variants(nlp):
     single = [
         {"tags": ["NFP"], "variants": ["…", "..."]},
         {"tags": [":"], "variants": ["-", "—", "–", "--", "---", "——"]},
     ]
+    # fmt: off
+    words = ["\n\n", "A", "\t", "B", "a", "b", "…", "...", "-", "—", "–", "--", "---", "——"]
+    tags = ["_SP", "NN", "\t", "NN", "NN", "NN", "NFP", "NFP", ":", ":", ":", ":", ":", ":"]
+    # fmt: on
+    spaces = [True] * len(words)
+    spaces[0] = False
+    spaces[2] = False
+    doc = Doc(nlp.vocab, words=words, spaces=spaces, tags=tags)
     augmenter = create_orth_variants_augmenter(
         level=0.2, lower=0.5, orth_variants={"single": single}
     )
-    with make_docbin([doc]) as output_file:
+    with make_docbin([doc] * 10) as output_file:
         reader = Corpus(output_file, augmenter=augmenter)
-        # Due to randomness, only test that it works without errors for now
+        # Due to randomness, only test that it works without errors
         list(reader(nlp))
+    # check that the following settings lowercase everything
+    augmenter = create_orth_variants_augmenter(
+        level=1.0, lower=1.0, orth_variants={"single": single}
+    )
+    with make_docbin([doc] * 10) as output_file:
+        reader = Corpus(output_file, augmenter=augmenter)
+        for example in reader(nlp):
+            for token in example.reference:
+                assert token.text == token.text.lower()
+    # check that lowercasing is applied without tags
+    doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words))
+    augmenter = create_orth_variants_augmenter(
+        level=1.0, lower=1.0, orth_variants={"single": single}
+    )
+    with make_docbin([doc] * 10) as output_file:
+        reader = Corpus(output_file, augmenter=augmenter)
+        for example in reader(nlp):
+            for ex_token, doc_token in zip(example.reference, doc):
+                assert ex_token.text == doc_token.text.lower()
+    # check that no lowercasing is applied with lower=0.0
+    doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words))
+    augmenter = create_orth_variants_augmenter(
+        level=1.0, lower=0.0, orth_variants={"single": single}
+    )
+    with make_docbin([doc] * 10) as output_file:
+        reader = Corpus(output_file, augmenter=augmenter)
+        for example in reader(nlp):
+            for ex_token, doc_token in zip(example.reference, doc):
+                assert ex_token.text == doc_token.text


 def test_lowercase_augmenter(nlp, doc):
     augmenter = create_lower_casing_augmenter(level=1.0)
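
Note: the `make_docbin` helper used throughout these tests is not part of this diff. A rough stand-in, assuming it simply writes the docs to a temporary `.spacy` file for `Corpus` to read (the real test utility may differ):

import tempfile
from contextlib import contextmanager
from pathlib import Path

from spacy.tokens import DocBin


@contextmanager
def make_docbin(docs, name="docs.spacy"):
    # Serialize the docs to a temporary DocBin file and yield its path,
    # cleaning up once the with-block exits.
    with tempfile.TemporaryDirectory() as tmpdir:
        output_file = Path(tmpdir) / name
        DocBin(docs=docs).to_disk(output_file)
        yield output_file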
@@ -66,6 +106,21 @@ def test_lowercase_augmenter(nlp, doc):
         assert ref_ent.text == orig_ent.text.lower()
     assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc]
+    # check that augmentation works when lowercasing leads to different
+    # predicted tokenization
+    words = ["A", "B", "CCC."]
+    doc = Doc(nlp.vocab, words=words)
+    with make_docbin([doc]) as output_file:
+        reader = Corpus(output_file, augmenter=augmenter)
+        corpus = list(reader(nlp))
+    eg = corpus[0]
+    assert eg.reference.text == doc.text.lower()
+    assert eg.predicted.text == doc.text.lower()
+    assert [t.text for t in eg.reference] == [t.lower() for t in words]
+    assert [t.text for t in eg.predicted] == [
+        t.text for t in nlp.make_doc(doc.text.lower())
+    ]


 @pytest.mark.filterwarnings("ignore::UserWarning")
 def test_custom_data_augmentation(nlp, doc):
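
The new assertions above rely on the fact that a reference `Doc` built from explicit words can keep a tokenization that `nlp.make_doc` would not reproduce once the text is lowercased. A small illustration; the exact token output depends on the tokenizer and spaCy version:

from spacy.lang.en import English
from spacy.tokens import Doc

nlp = English()
# The reference Doc keeps "CCC." as a single token because the words are
# given explicitly ...
ref = Doc(nlp.vocab, words=["A", "B", "CCC."])
# ... while tokenizing the lowercased raw text may split the final period,
# so the predicted side can end up with more tokens than the reference.
pred = nlp.make_doc(ref.text.lower())
print([t.text for t in ref])   # ['A', 'B', 'CCC.']
print([t.text for t in pred])  # tokenizer-dependent, e.g. ['a', 'b', 'ccc', '.']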


@@ -1,12 +1,10 @@
 from typing import Callable, Iterator, Dict, List, Tuple, TYPE_CHECKING
 import random
 import itertools
-import copy
 from functools import partial
 from pydantic import BaseModel, StrictStr

 from ..util import registry
-from ..tokens import Doc
 from .example import Example

 if TYPE_CHECKING:
@@ -71,7 +69,7 @@ def lower_casing_augmenter(
     else:
         example_dict = example.to_dict()
         doc = nlp.make_doc(example.text.lower())
-        example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in doc]
+        example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference]
         yield example.from_dict(doc, example_dict)
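
With this change the lowercased `ORTH` values come from `example.reference`, so their count always matches the other reference annotations, even when retokenizing the lowercased text yields a different number of tokens. A sketch of applying the augmenter to a single `Example`; the words are made up for illustration:

from spacy.lang.en import English
from spacy.tokens import Doc
from spacy.training import Example
from spacy.training.augment import create_lower_casing_augmenter

nlp = English()
augmenter = create_lower_casing_augmenter(level=1.0)  # always lowercase

# Reference with a tokenization that nlp.make_doc would not reproduce.
ref = Doc(nlp.vocab, words=["A", "B", "CCC."])
example = Example(nlp.make_doc(ref.text), ref)

augmented = list(augmenter(nlp, example))[0]
# The reference keeps its three tokens, just lowercased ...
assert [t.text for t in augmented.reference] == ["a", "b", "ccc."]
# ... while the predicted doc is retokenized from the lowercased text.
assert augmented.predicted.text == ref.text.lower()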
@@ -88,9 +86,6 @@ def orth_variants_augmenter(
     else:
         raw_text = example.text
         orig_dict = example.to_dict()
-        if not orig_dict["token_annotation"]:
-            yield example
-        else:
         variant_text, variant_token_annot = make_orth_variants(
             nlp,
             raw_text,
@@ -98,14 +93,8 @@ def orth_variants_augmenter(
             orth_variants,
             lower=raw_text is not None and random.random() < lower,
         )
-        if variant_text:
-            doc = nlp.make_doc(variant_text)
-        else:
-            doc = Doc(nlp.vocab, words=variant_token_annot["ORTH"])
-            variant_token_annot["ORTH"] = [w.text for w in doc]
-            variant_token_annot["SPACY"] = [w.whitespace_ for w in doc]
         orig_dict["token_annotation"] = variant_token_annot
-        yield example.from_dict(doc, orig_dict)
+        yield example.from_dict(nlp.make_doc(variant_text), orig_dict)


 def make_orth_variants(
@@ -116,16 +105,20 @@ def make_orth_variants(
     *,
     lower: bool = False,
 ) -> Tuple[str, Dict[str, List[str]]]:
-    orig_token_dict = copy.deepcopy(token_dict)
-    ndsv = orth_variants.get("single", [])
-    ndpv = orth_variants.get("paired", [])
     words = token_dict.get("ORTH", [])
     tags = token_dict.get("TAG", [])
-    # keep unmodified if words or tags are not defined
-    if words and tags:
+    # keep unmodified if words are not defined
+    if not words:
+        return raw, token_dict
     if lower:
         words = [w.lower() for w in words]
+        raw = raw.lower()
+    # if no tags, only lowercase
+    if not tags:
+        token_dict["ORTH"] = words
+        return raw, token_dict
     # single variants
+    ndsv = orth_variants.get("single", [])
     punct_choices = [random.choice(x["variants"]) for x in ndsv]
     for word_idx in range(len(words)):
         for punct_idx in range(len(ndsv)):
@@ -135,6 +128,7 @@ def make_orth_variants(
             ):
                 words[word_idx] = punct_choices[punct_idx]
     # paired variants
+    ndpv = orth_variants.get("paired", [])
     punct_choices = [random.choice(x["variants"]) for x in ndpv]
     for word_idx in range(len(words)):
         for punct_idx in range(len(ndpv)):
@@ -154,50 +148,10 @@ def make_orth_variants(
                             pair_idx = pair.index(words[word_idx])
                 words[word_idx] = punct_choices[punct_idx][pair_idx]
     token_dict["ORTH"] = words
-    token_dict["TAG"] = tags
-    # modify raw
-    if raw is not None:
-        variants = []
-        for single_variants in ndsv:
-            variants.extend(single_variants["variants"])
-        for paired_variants in ndpv:
-            variants.extend(
-                list(itertools.chain.from_iterable(paired_variants["variants"]))
-            )
-        # store variants in reverse length order to be able to prioritize
-        # longer matches (e.g., "---" before "--")
-        variants = sorted(variants, key=lambda x: len(x))
-        variants.reverse()
-        variant_raw = ""
-        raw_idx = 0
-        # add initial whitespace
-        while raw_idx < len(raw) and raw[raw_idx].isspace():
-            variant_raw += raw[raw_idx]
-            raw_idx += 1
-        for word in words:
-            match_found = False
-            # skip whitespace words
-            if word.isspace():
-                match_found = True
-            # add identical word
-            elif word not in variants and raw[raw_idx:].startswith(word):
-                variant_raw += word
-                raw_idx += len(word)
-                match_found = True
-            # add variant word
-            else:
-                for variant in variants:
-                    if not match_found and raw[raw_idx:].startswith(variant):
-                        raw_idx += len(variant)
-                        variant_raw += word
-                        match_found = True
-            # something went wrong, abort
-            # (add a warning message?)
-            if not match_found:
-                return raw, orig_token_dict
-            # add following whitespace
-            while raw_idx < len(raw) and raw[raw_idx].isspace():
-                variant_raw += raw[raw_idx]
-                raw_idx += 1
-        raw = variant_raw
+    # construct modified raw text from words and spaces
+    raw = ""
+    for orth, spacy in zip(token_dict["ORTH"], token_dict["SPACY"]):
+        raw += orth
+        if spacy:
+            raw += " "
     return raw, token_dict
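
The new tail of `make_orth_variants` rebuilds the raw text from the token annotation instead of patching the original string: every token whose `SPACY` flag is true is followed by a single space, and whitespace tokens carry their own text. The same idea in isolation, as a standalone sketch rather than spaCy's code path:

from typing import Dict, List


def rebuild_raw(token_annotation: Dict[str, List]) -> str:
    # Concatenate tokens, adding one space wherever the SPACY flag
    # (trailing-space marker) is True.
    raw = ""
    for orth, spacy in zip(token_annotation["ORTH"], token_annotation["SPACY"]):
        raw += orth
        if spacy:
            raw += " "
    return raw


print(rebuild_raw({"ORTH": ["a", "b", "ccc."], "SPACY": [True, True, False]}))
# -> "a b ccc."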