from typing import (
    Any,
    Callable,
    Iterator,
    Dict,
    List,
    Tuple,
    Union,
    TYPE_CHECKING,
)
import random
import itertools
from functools import partial
from pydantic import BaseModel, StrictStr

from ..util import registry
from .example import Example

if TYPE_CHECKING:
    from ..language import Language  # noqa: F401


class OrthVariantsSingle(BaseModel):
    # Tags a token must carry for a replacement to apply.
    tags: List[StrictStr]
    # Interchangeable spellings; any one of them may replace another.
    variants: List[StrictStr]


class OrthVariantsPaired(BaseModel):
    # Tags a token must carry for a replacement to apply.
    tags: List[StrictStr]
    # Pairs of spellings (e.g. opening/closing forms); position in the pair
    # is preserved on replacement.
    variants: List[List[StrictStr]]


class OrthVariants(BaseModel):
    # Both fields are lists, so the empty default is a list as well (pydantic
    # does not validate defaults, so a dict default would leak through).
    paired: List[OrthVariantsPaired] = []
    single: List[OrthVariantsSingle] = []


def _orth_variants_as_dict(
    orth_variants: Union[OrthVariants, Dict[str, Any]]
) -> Dict[str, Any]:
    """Return the orth-variant tables as a plain dict.

    The replacement code reads the tables with ``.get()`` and ``[...]``, so a
    validated ``OrthVariants`` model is dumped to a dict first. Anything that
    is not a pydantic model is returned unchanged.
    """
    if isinstance(orth_variants, BaseModel):
        # pydantic v2 names the export method ``model_dump``; v1 has ``dict``.
        dump = getattr(orth_variants, "model_dump", None)
        return dump() if dump is not None else orth_variants.dict()
    return orth_variants


@registry.augmenters("spacy.orth_variants.v1")
def create_orth_variants_augmenter(
    level: float, lower: float, orth_variants: OrthVariants
) -> Callable[["Language", Example], Iterator[Example]]:
    """Create a data augmentation callback that uses orth-variant replacement.
    The callback can be added to a corpus or other data iterator during
    training.

    level (float): The probability (0.0-1.0) that a given example is augmented.
    lower (float): The probability (0.0-1.0) that an augmented example is also
        lowercased.
    orth_variants (OrthVariants): The single and paired orth variants, either
        as a validated model or as the equivalent dict. Typically loaded from
        a JSON file.
    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
    """
    return partial(
        orth_variants_augmenter,
        # Normalize once here so the per-example code always sees a dict.
        orth_variants=_orth_variants_as_dict(orth_variants),
        level=level,
        lower=lower,
    )


@registry.augmenters("spacy.lower_case.v1")
def create_lower_casing_augmenter(
    level: float,
) -> Callable[["Language", Example], Iterator[Example]]:
    """Create a data augmentation callback that converts documents to lowercase.
    The callback can be added to a corpus or other data iterator during
    training.

    level (float): The probability (0.0-1.0) that a given example is augmented.
    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
    """
    return partial(lower_casing_augmenter, level=level)


def dont_augment(nlp: "Language", example: Example) -> Iterator[Example]:
    """Augmenter that yields the example unchanged."""
    yield example


def lower_casing_augmenter(
    nlp: "Language", example: Example, *, level: float
) -> Iterator[Example]:
    """Yield a lowercased copy of ``example`` with probability ``level``.

    Otherwise the example is yielded unchanged. The predicted doc is rebuilt
    from the lowercased text, and the reference ``ORTH`` annotation is
    replaced with the lowercased reference tokens.
    """
    if random.random() >= level:
        yield example
    else:
        example_dict = example.to_dict()
        doc = nlp.make_doc(example.text.lower())
        example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference]
        yield example.from_dict(doc, example_dict)


def orth_variants_augmenter(
    nlp: "Language",
    example: Example,
    orth_variants: dict,
    *,
    level: float = 0.0,
    lower: float = 0.0,
) -> Iterator[Example]:
    """Yield ``example`` with orth variants swapped in, with probability ``level``.

    When augmenting, the example is additionally lowercased with probability
    ``lower``. Otherwise the example is yielded unchanged.
    """
    if random.random() >= level:
        yield example
    else:
        raw_text = example.text
        orig_dict = example.to_dict()
        variant_text, variant_token_annot = make_orth_variants(
            nlp,
            raw_text,
            orig_dict["token_annotation"],
            orth_variants,
            lower=raw_text is not None and random.random() < lower,
        )
        orig_dict["token_annotation"] = variant_token_annot
        yield example.from_dict(nlp.make_doc(variant_text), orig_dict)


def _construct_raw_text(token_dict: Dict[str, List]) -> str:
    """Rebuild raw text from ``ORTH`` tokens and their ``SPACY`` flags."""
    return "".join(
        orth + " " if spacy else orth
        for orth, spacy in zip(token_dict["ORTH"], token_dict["SPACY"])
    )


def make_orth_variants(
    nlp: "Language",
    raw: str,
    token_dict: Dict[str, List[str]],
    orth_variants: Dict[str, List[Dict[str, List[str]]]],
    *,
    lower: bool = False,
) -> Tuple[str, Dict[str, List[str]]]:
    """Replace token spellings with randomly chosen orth variants.

    nlp (Language): Unused here; kept for interface compatibility.
    raw (str): The original text.
    token_dict (Dict[str, List[str]]): Token annotation with ``ORTH`` and,
        optionally, ``TAG`` and ``SPACY`` entries. It is not modified.
    orth_variants: ``single`` and ``paired`` variant tables (a dict, or an
        ``OrthVariants`` model).
    lower (bool): Lowercase the words and text before replacement.
    RETURNS (Tuple[str, Dict[str, List[str]]]): The new text and a new token
        annotation dict. Without ``ORTH`` the inputs are returned as-is.
        Without ``TAG`` only lowercasing is applied. Otherwise the text is
        rebuilt from ``ORTH`` and ``SPACY``, and a missing ``SPACY`` entry
        raises ``KeyError``.
    """
    orth_variants = _orth_variants_as_dict(orth_variants)
    # Work on copies so the caller's annotation (and its ORTH list) is left
    # untouched; only the returned dict reflects the replacements.
    token_dict = dict(token_dict)
    words = list(token_dict.get("ORTH", []))
    tags = token_dict.get("TAG", [])
    # keep unmodified if words are not defined
    if not words:
        return raw, token_dict
    if lower:
        words = [w.lower() for w in words]
        raw = raw.lower()
    # if no tags, only lowercase
    if not tags:
        token_dict["ORTH"] = words
        return raw, token_dict
    # single variants: one replacement is drawn per variant set up front and
    # applied to every matching token
    single = orth_variants.get("single", [])
    single_choices = [random.choice(x["variants"]) for x in single]
    for i in range(len(words)):
        for variant_set, choice in zip(single, single_choices):
            if tags[i] in variant_set["tags"] and words[i] in variant_set["variants"]:
                words[i] = choice
    # paired variants: one pair is drawn per variant set up front; each
    # matching token takes the member of that pair at its own position
    paired = orth_variants.get("paired", [])
    paired_choices = [random.choice(x["variants"]) for x in paired]
    for i in range(len(words)):
        for variant_set, choice in zip(paired, paired_choices):
            pair_tags = variant_set["tags"]
            pairs = variant_set["variants"]
            if tags[i] in pair_tags and words[i] in itertools.chain.from_iterable(
                pairs
            ):
                # backup option: random left vs. right from pair
                pair_idx = random.choice([0, 1])
                # best option: rely on paired POS tags like `` / ''
                if len(pair_tags) == 2:
                    pair_idx = pair_tags.index(tags[i])
                # next best option: rely on position in variants
                # (may not be unambiguous, so order of variants matters)
                else:
                    for pair in pairs:
                        if words[i] in pair:
                            pair_idx = pair.index(words[i])
                words[i] = choice[pair_idx]
    token_dict["ORTH"] = words
    # construct modified raw text from words and spaces
    raw = _construct_raw_text(token_dict)
    return raw, token_dict