diff --git a/spacy/tests/training/test_augmenters.py b/spacy/tests/training/test_augmenters.py
index 0bd4d5ef2..43a78e4b0 100644
--- a/spacy/tests/training/test_augmenters.py
+++ b/spacy/tests/training/test_augmenters.py
@@ -38,19 +38,59 @@ def doc(nlp):
 
 
 @pytest.mark.filterwarnings("ignore::UserWarning")
-def test_make_orth_variants(nlp, doc):
+def test_make_orth_variants(nlp):
     single = [
         {"tags": ["NFP"], "variants": ["…", "..."]},
         {"tags": [":"], "variants": ["-", "—", "–", "--", "---", "——"]},
     ]
+    # fmt: off
+    words = ["\n\n", "A", "\t", "B", "a", "b", "…", "...", "-", "—", "–", "--", "---", "——"]
+    tags = ["_SP", "NN", "\t", "NN", "NN", "NN", "NFP", "NFP", ":", ":", ":", ":", ":", ":"]
+    # fmt: on
+    spaces = [True] * len(words)
+    spaces[0] = False
+    spaces[2] = False
+    doc = Doc(nlp.vocab, words=words, spaces=spaces, tags=tags)
     augmenter = create_orth_variants_augmenter(
         level=0.2, lower=0.5, orth_variants={"single": single}
     )
-    with make_docbin([doc]) as output_file:
+    with make_docbin([doc] * 10) as output_file:
         reader = Corpus(output_file, augmenter=augmenter)
-        # Due to randomness, only test that it works without errors for now
+        # Due to randomness, only test that it works without errors
         list(reader(nlp))
+    # check that the following settings lowercase everything
+    augmenter = create_orth_variants_augmenter(
+        level=1.0, lower=1.0, orth_variants={"single": single}
+    )
+    with make_docbin([doc] * 10) as output_file:
+        reader = Corpus(output_file, augmenter=augmenter)
+        for example in reader(nlp):
+            for token in example.reference:
+                assert token.text == token.text.lower()
+
+    # check that lowercasing is applied without tags
+    doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words))
+    augmenter = create_orth_variants_augmenter(
+        level=1.0, lower=1.0, orth_variants={"single": single}
+    )
+    with make_docbin([doc] * 10) as output_file:
+        reader = Corpus(output_file, augmenter=augmenter)
+        for example in reader(nlp):
+            for ex_token, doc_token in zip(example.reference, doc):
+                assert ex_token.text == doc_token.text.lower()
+
+    # check that no lowercasing is applied with lower=0.0
+    doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words))
+    augmenter = create_orth_variants_augmenter(
+        level=1.0, lower=0.0, orth_variants={"single": single}
+    )
+    with make_docbin([doc] * 10) as output_file:
+        reader = Corpus(output_file, augmenter=augmenter)
+        for example in reader(nlp):
+            for ex_token, doc_token in zip(example.reference, doc):
+                assert ex_token.text == doc_token.text
+
 
 
 def test_lowercase_augmenter(nlp, doc):
     augmenter = create_lower_casing_augmenter(level=1.0)
@@ -66,6 +106,21 @@ def test_lowercase_augmenter(nlp, doc):
         assert ref_ent.text == orig_ent.text.lower()
     assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc]
+    # check that augmentation works when lowercasing leads to different
+    # predicted tokenization
+    words = ["A", "B", "CCC."]
+    doc = Doc(nlp.vocab, words=words)
+    with make_docbin([doc]) as output_file:
+        reader = Corpus(output_file, augmenter=augmenter)
+        corpus = list(reader(nlp))
+    eg = corpus[0]
+    assert eg.reference.text == doc.text.lower()
+    assert eg.predicted.text == doc.text.lower()
+    assert [t.text for t in eg.reference] == [t.lower() for t in words]
+    assert [t.text for t in eg.predicted] == [
+        t.text for t in nlp.make_doc(doc.text.lower())
+    ]
+
 
 
 @pytest.mark.filterwarnings("ignore::UserWarning")
 def test_custom_data_augmentation(nlp, doc):
diff --git a/spacy/training/augment.py b/spacy/training/augment.py
index 13ae45bd2..0dae92143 100644
--- a/spacy/training/augment.py
+++ b/spacy/training/augment.py
@@ -1,12 +1,10 @@
 from typing import Callable, Iterator, Dict, List, Tuple, TYPE_CHECKING
 import random
 import itertools
-import copy
 from functools import partial
 from pydantic import BaseModel, StrictStr
 
 from ..util import registry
-from ..tokens import Doc
 from .example import Example
 
 if TYPE_CHECKING:
@@ -71,7 +69,7 @@ def lower_casing_augmenter(
     else:
         example_dict = example.to_dict()
         doc = nlp.make_doc(example.text.lower())
-        example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in doc]
+        example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference]
         yield example.from_dict(doc, example_dict)
 
 
@@ -88,24 +86,15 @@ def orth_variants_augmenter(
     else:
         raw_text = example.text
         orig_dict = example.to_dict()
-        if not orig_dict["token_annotation"]:
-            yield example
-        else:
-            variant_text, variant_token_annot = make_orth_variants(
-                nlp,
-                raw_text,
-                orig_dict["token_annotation"],
-                orth_variants,
-                lower=raw_text is not None and random.random() < lower,
-            )
-            if variant_text:
-                doc = nlp.make_doc(variant_text)
-            else:
-                doc = Doc(nlp.vocab, words=variant_token_annot["ORTH"])
-            variant_token_annot["ORTH"] = [w.text for w in doc]
-            variant_token_annot["SPACY"] = [w.whitespace_ for w in doc]
-            orig_dict["token_annotation"] = variant_token_annot
-            yield example.from_dict(doc, orig_dict)
+        variant_text, variant_token_annot = make_orth_variants(
+            nlp,
+            raw_text,
+            orig_dict["token_annotation"],
+            orth_variants,
+            lower=raw_text is not None and random.random() < lower,
+        )
+        orig_dict["token_annotation"] = variant_token_annot
+        yield example.from_dict(nlp.make_doc(variant_text), orig_dict)
 
 
 def make_orth_variants(
@@ -116,88 +105,53 @@
     nlp: "Language",
     raw: str,
     token_dict: Dict[str, List[str]],
     orth_variants: Dict[str, List[Dict[str, List[str]]]],
     *,
     lower: bool = False,
 ) -> Tuple[str, Dict[str, List[str]]]:
-    orig_token_dict = copy.deepcopy(token_dict)
-    ndsv = orth_variants.get("single", [])
-    ndpv = orth_variants.get("paired", [])
     words = token_dict.get("ORTH", [])
     tags = token_dict.get("TAG", [])
-    # keep unmodified if words or tags are not defined
-    if words and tags:
-        if lower:
-            words = [w.lower() for w in words]
-        # single variants
-        punct_choices = [random.choice(x["variants"]) for x in ndsv]
-        for word_idx in range(len(words)):
-            for punct_idx in range(len(ndsv)):
-                if (
-                    tags[word_idx] in ndsv[punct_idx]["tags"]
-                    and words[word_idx] in ndsv[punct_idx]["variants"]
-                ):
-                    words[word_idx] = punct_choices[punct_idx]
-        # paired variants
-        punct_choices = [random.choice(x["variants"]) for x in ndpv]
-        for word_idx in range(len(words)):
-            for punct_idx in range(len(ndpv)):
-                if tags[word_idx] in ndpv[punct_idx]["tags"] and words[
-                    word_idx
-                ] in itertools.chain.from_iterable(ndpv[punct_idx]["variants"]):
-                    # backup option: random left vs. right from pair
-                    pair_idx = random.choice([0, 1])
-                    # best option: rely on paired POS tags like `` / ''
-                    if len(ndpv[punct_idx]["tags"]) == 2:
-                        pair_idx = ndpv[punct_idx]["tags"].index(tags[word_idx])
-                    # next best option: rely on position in variants
-                    # (may not be unambiguous, so order of variants matters)
-                    else:
-                        for pair in ndpv[punct_idx]["variants"]:
-                            if words[word_idx] in pair:
-                                pair_idx = pair.index(words[word_idx])
-                    words[word_idx] = punct_choices[punct_idx][pair_idx]
-        token_dict["ORTH"] = words
-        token_dict["TAG"] = tags
-    # modify raw
-    if raw is not None:
-        variants = []
-        for single_variants in ndsv:
-            variants.extend(single_variants["variants"])
-        for paired_variants in ndpv:
-            variants.extend(
-                list(itertools.chain.from_iterable(paired_variants["variants"]))
-            )
-        # store variants in reverse length order to be able to prioritize
-        # longer matches (e.g., "---" before "--")
-        variants = sorted(variants, key=lambda x: len(x))
-        variants.reverse()
-        variant_raw = ""
-        raw_idx = 0
-        # add initial whitespace
-        while raw_idx < len(raw) and raw[raw_idx].isspace():
-            variant_raw += raw[raw_idx]
-            raw_idx += 1
-        for word in words:
-            match_found = False
-            # skip whitespace words
-            if word.isspace():
-                match_found = True
-            # add identical word
-            elif word not in variants and raw[raw_idx:].startswith(word):
-                variant_raw += word
-                raw_idx += len(word)
-                match_found = True
-            # add variant word
-            else:
-                for variant in variants:
-                    if not match_found and raw[raw_idx:].startswith(variant):
-                        raw_idx += len(variant)
-                        variant_raw += word
-                        match_found = True
-            # something went wrong, abort
-            # (add a warning message?)
-            if not match_found:
-                return raw, orig_token_dict
-            # add following whitespace
-            while raw_idx < len(raw) and raw[raw_idx].isspace():
-                variant_raw += raw[raw_idx]
-                raw_idx += 1
-        raw = variant_raw
+    # keep unmodified if words are not defined
+    if not words:
+        return raw, token_dict
+    if lower:
+        words = [w.lower() for w in words]
+        raw = raw.lower()
+    # if no tags, only lowercase
+    if not tags:
+        token_dict["ORTH"] = words
+        return raw, token_dict
+    # single variants
+    ndsv = orth_variants.get("single", [])
+    punct_choices = [random.choice(x["variants"]) for x in ndsv]
+    for word_idx in range(len(words)):
+        for punct_idx in range(len(ndsv)):
+            if (
+                tags[word_idx] in ndsv[punct_idx]["tags"]
+                and words[word_idx] in ndsv[punct_idx]["variants"]
+            ):
+                words[word_idx] = punct_choices[punct_idx]
+    # paired variants
+    ndpv = orth_variants.get("paired", [])
+    punct_choices = [random.choice(x["variants"]) for x in ndpv]
+    for word_idx in range(len(words)):
+        for punct_idx in range(len(ndpv)):
+            if tags[word_idx] in ndpv[punct_idx]["tags"] and words[
+                word_idx
+            ] in itertools.chain.from_iterable(ndpv[punct_idx]["variants"]):
+                # backup option: random left vs. right from pair
+                pair_idx = random.choice([0, 1])
+                # best option: rely on paired POS tags like `` / ''
+                if len(ndpv[punct_idx]["tags"]) == 2:
+                    pair_idx = ndpv[punct_idx]["tags"].index(tags[word_idx])
+                # next best option: rely on position in variants
+                # (may not be unambiguous, so order of variants matters)
+                else:
+                    for pair in ndpv[punct_idx]["variants"]:
+                        if words[word_idx] in pair:
+                            pair_idx = pair.index(words[word_idx])
+                words[word_idx] = punct_choices[punct_idx][pair_idx]
+    token_dict["ORTH"] = words
+    # construct modified raw text from words and spaces
+    raw = ""
+    for orth, spacy in zip(token_dict["ORTH"], token_dict["SPACY"]):
+        raw += orth
+        if spacy:
+            raw += " "
     return raw, token_dict
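
Note (not part of the patch): a minimal usage sketch of the augmenter exercised by the tests above. Only create_orth_variants_augmenter and Corpus(..., augmenter=...) are taken from the diff itself; the blank English pipeline, the sample tokens, and the output path are placeholder assumptions.

    # hypothetical usage sketch; pipeline, tokens, and path are assumptions, not from the patch
    from spacy.lang.en import English
    from spacy.tokens import Doc, DocBin
    from spacy.training import Corpus
    from spacy.training.augment import create_orth_variants_augmenter

    nlp = English()
    words = ["A", "test", "..."]
    tags = ["NN", "NN", "NFP"]
    doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words), tags=tags)
    single = [{"tags": ["NFP"], "variants": ["…", "..."]}]
    augmenter = create_orth_variants_augmenter(
        level=1.0, lower=1.0, orth_variants={"single": single}
    )
    DocBin(docs=[doc]).to_disk("./sample.spacy")  # placeholder path
    reader = Corpus("./sample.spacy", augmenter=augmenter)
    for example in reader(nlp):
        # with level=1.0 and lower=1.0 every reference token is lowercased
        print([t.text for t in example.reference])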