diff --git a/spacy/tests/training/test_augmenters.py b/spacy/tests/training/test_augmenters.py index 43a78e4b0..e3639c5da 100644 --- a/spacy/tests/training/test_augmenters.py +++ b/spacy/tests/training/test_augmenters.py @@ -1,9 +1,11 @@ import pytest -from spacy.training import Corpus +from spacy.pipeline._parser_internals.nonproj import contains_cycle +from spacy.training import Corpus, Example from spacy.training.augment import create_orth_variants_augmenter from spacy.training.augment import create_lower_casing_augmenter +from spacy.training.augment import make_whitespace_variant from spacy.lang.en import English -from spacy.tokens import DocBin, Doc +from spacy.tokens import DocBin, Doc, Span from contextlib import contextmanager import random @@ -153,3 +155,84 @@ def test_custom_data_augmentation(nlp, doc): ents = [(e.start, e.end, e.label) for e in doc.ents] assert [(e.start, e.end, e.label) for e in corpus[0].reference.ents] == ents assert [(e.start, e.end, e.label) for e in corpus[1].reference.ents] == ents + + +def test_make_whitespace_variant(nlp): + # fmt: off + text = "They flew to New York City.\nThen they drove to Washington, D.C." + words = ["They", "flew", "to", "New", "York", "City", ".", "\n", "Then", "they", "drove", "to", "Washington", ",", "D.C."] + spaces = [True, True, True, True, True, False, False, False, True, True, True, True, False, True, False] + tags = ["PRP", "VBD", "IN", "NNP", "NNP", "NNP", ".", "_SP", "RB", "PRP", "VBD", "IN", "NNP", ",", "NNP"] + lemmas = ["they", "fly", "to", "New", "York", "City", ".", "\n", "then", "they", "drive", "to", "Washington", ",", "D.C."] + heads = [1, 1, 1, 4, 5, 2, 1, 10, 10, 10, 10, 10, 11, 12, 12] + deps = ["nsubj", "ROOT", "prep", "compound", "compound", "pobj", "punct", "dep", "advmod", "nsubj", "ROOT", "prep", "pobj", "punct", "appos"] + ents = ["O", "O", "O", "B-GPE", "I-GPE", "I-GPE", "O", "O", "O", "O", "O", "O", "B-GPE", "O", "B-GPE"] + # fmt: on + doc = Doc( + nlp.vocab, + words=words, + spaces=spaces, + tags=tags, + lemmas=lemmas, + heads=heads, + deps=deps, + ents=ents, + ) + assert doc.text == text + example = Example(nlp.make_doc(text), doc) + # whitespace is only added internally in entity spans + mod_ex = make_whitespace_variant(nlp, example, " ", 3) + assert mod_ex.reference.ents[0].text == "New York City" + mod_ex = make_whitespace_variant(nlp, example, " ", 4) + assert mod_ex.reference.ents[0].text == "New York City" + mod_ex = make_whitespace_variant(nlp, example, " ", 5) + assert mod_ex.reference.ents[0].text == "New York City" + mod_ex = make_whitespace_variant(nlp, example, " ", 6) + assert mod_ex.reference.ents[0].text == "New York City" + # add a space at every possible position + for i in range(len(doc) + 1): + mod_ex = make_whitespace_variant(nlp, example, " ", i) + assert mod_ex.reference[i].is_space + # adds annotation when the doc contains at least partial annotation + assert [t.tag_ for t in mod_ex.reference] == tags[:i] + ["_SP"] + tags[i:] + assert [t.lemma_ for t in mod_ex.reference] == lemmas[:i] + [" "] + lemmas[i:] + assert [t.dep_ for t in mod_ex.reference] == deps[:i] + ["dep"] + deps[i:] + # does not add partial annotation if doc does not contain this feature + assert not mod_ex.reference.has_annotation("POS") + assert not mod_ex.reference.has_annotation("MORPH") + # produces well-formed trees + assert not contains_cycle([t.head.i for t in mod_ex.reference]) + assert len(list(doc.sents)) == 2 + if i == 0: + assert mod_ex.reference[i].head.i == 1 + else: + assert mod_ex.reference[i].head.i == i - 1 + # adding another space also produces well-formed trees + for j in (3, 8, 10): + mod_ex2 = make_whitespace_variant(nlp, mod_ex, "\t\t\n", j) + assert not contains_cycle([t.head.i for t in mod_ex2.reference]) + assert len(list(doc.sents)) == 2 + assert mod_ex2.reference[j].head.i == j - 1 + # entities are well-formed + assert len(doc.ents) == len(mod_ex.reference.ents) + for ent in mod_ex.reference.ents: + assert not ent[0].is_space + assert not ent[-1].is_space + + # no modifications if: + # partial dependencies + example.reference[0].dep_ = "" + mod_ex = make_whitespace_variant(nlp, example, " ", 5) + assert mod_ex.text == example.reference.text + example.reference[0].dep_ = "nsubj" # reset + + # spans + example.reference.spans["spans"] = [example.reference[0:5]] + mod_ex = make_whitespace_variant(nlp, example, " ", 5) + assert mod_ex.text == example.reference.text + del example.reference.spans["spans"] # reset + + # links + example.reference.ents = [Span(doc, 0, 2, label="ENT", kb_id="Q123")] + mod_ex = make_whitespace_variant(nlp, example, " ", 5) + assert mod_ex.text == example.reference.text diff --git a/spacy/training/augment.py b/spacy/training/augment.py index 63b54034c..59a39c7ee 100644 --- a/spacy/training/augment.py +++ b/spacy/training/augment.py @@ -1,4 +1,5 @@ from typing import Callable, Iterator, Dict, List, Tuple, TYPE_CHECKING +from typing import Optional import random import itertools from functools import partial @@ -11,32 +12,87 @@ if TYPE_CHECKING: from ..language import Language # noqa: F401 -class OrthVariantsSingle(BaseModel): - tags: List[StrictStr] - variants: List[StrictStr] +@registry.augmenters("spacy.combined_augmenter.v1") +def create_combined_augmenter( + lower_level: float, + orth_level: float, + orth_variants: Optional[Dict[str, List[Dict]]], + whitespace_level: float, + whitespace_per_token: float, + whitespace_variants: Optional[List[str]], +) -> Callable[["Language", Example], Iterator[Example]]: + """Create a data augmentation callback that uses orth-variant replacement. + The callback can be added to a corpus or other data iterator during training. + + lower_level (float): The percentage of texts that will be lowercased. + orth_level (float): The percentage of texts that will be augmented. + orth_variants (Optional[Dict[str, List[Dict]]]): A dictionary containing the + single and paired orth variants. Typically loaded from a JSON file. + whitespace_level (float): The percentage of texts that will have whitespace + tokens inserted. + whitespace_per_token (float): The number of whitespace tokens to insert in + the modified doc as a percentage of the doc length. + whitespace_variants (Optional[List[str]]): The whitespace token texts. + RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter. + """ + return partial( + combined_augmenter, + lower_level=lower_level, + orth_level=orth_level, + orth_variants=orth_variants, + whitespace_level=whitespace_level, + whitespace_per_token=whitespace_per_token, + whitespace_variants=whitespace_variants, + ) -class OrthVariantsPaired(BaseModel): - tags: List[StrictStr] - variants: List[List[StrictStr]] - - -class OrthVariants(BaseModel): - paired: List[OrthVariantsPaired] = [] - single: List[OrthVariantsSingle] = [] +def combined_augmenter( + nlp: "Language", + example: Example, + *, + lower_level: float = 0.0, + orth_level: float = 0.0, + orth_variants: Optional[Dict[str, List[Dict]]] = None, + whitespace_level: float = 0.0, + whitespace_per_token: float = 0.0, + whitespace_variants: Optional[List[str]] = None, +) -> Iterator[Example]: + if random.random() < lower_level: + example = make_lowercase_variant(nlp, example) + if orth_variants and random.random() < orth_level: + raw_text = example.text + orig_dict = example.to_dict() + variant_text, variant_token_annot = make_orth_variants( + nlp, + raw_text, + orig_dict["token_annotation"], + orth_variants, + lower=False, + ) + orig_dict["token_annotation"] = variant_token_annot + example = example.from_dict(nlp.make_doc(variant_text), orig_dict) + if whitespace_variants and random.random() < whitespace_level: + for _ in range(int(len(example.reference) * whitespace_per_token)): + example = make_whitespace_variant( + nlp, + example, + random.choice(whitespace_variants), + random.randrange(0, len(example.reference)), + ) + yield example @registry.augmenters("spacy.orth_variants.v1") def create_orth_variants_augmenter( - level: float, lower: float, orth_variants: OrthVariants + level: float, lower: float, orth_variants: Dict[str, List[Dict]] ) -> Callable[["Language", Example], Iterator[Example]]: """Create a data augmentation callback that uses orth-variant replacement. The callback can be added to a corpus or other data iterator during training. level (float): The percentage of texts that will be augmented. lower (float): The percentage of texts that will be lowercased. - orth_variants (Dict[str, dict]): A dictionary containing the single and - paired orth variants. Typically loaded from a JSON file. + orth_variants (Dict[str, List[Dict]]): A dictionary containing + the single and paired orth variants. Typically loaded from a JSON file. RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter. """ return partial( @@ -67,16 +123,20 @@ def lower_casing_augmenter( if random.random() >= level: yield example else: - example_dict = example.to_dict() - doc = nlp.make_doc(example.text.lower()) - example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference] - yield example.from_dict(doc, example_dict) + yield make_lowercase_variant(nlp, example) + + +def make_lowercase_variant(nlp: "Language", example: Example): + example_dict = example.to_dict() + doc = nlp.make_doc(example.text.lower()) + example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference] + return example.from_dict(doc, example_dict) def orth_variants_augmenter( nlp: "Language", example: Example, - orth_variants: Dict, + orth_variants: Dict[str, List[Dict]], *, level: float = 0.0, lower: float = 0.0, @@ -148,10 +208,132 @@ def make_orth_variants( pair_idx = pair.index(words[word_idx]) words[word_idx] = punct_choices[punct_idx][pair_idx] token_dict["ORTH"] = words - # construct modified raw text from words and spaces + raw = construct_modified_raw_text(token_dict) + return raw, token_dict + + +def make_whitespace_variant( + nlp: "Language", + example: Example, + whitespace: str, + position: int, +) -> Example: + """Insert the whitespace token at the specified token offset in the doc. + This is primarily intended for v2-compatible training data that doesn't + include links or spans. If the document includes links, spans, or partial + dependency annotation, it is returned without modifications. + + The augmentation follows the basics of the v2 space attachment policy, but + without a distinction between "real" and other tokens, so space tokens + may be attached to space tokens: + - at the beginning of a sentence attach the space token to the following + token + - otherwise attach the space token to the preceding token + + The augmenter does not attempt to consolidate adjacent whitespace in the + same way that the tokenizer would. + + The following annotation is used for the space token: + TAG: "_SP" + MORPH: "" + POS: "SPACE" + LEMMA: ORTH + DEP: "dep" + SENT_START: False + + The annotation for each attribute is only set for the space token if there + is already at least partial annotation for that attribute in the original + example. + + RETURNS (Example): Example with one additional space token. + """ + example_dict = example.to_dict() + doc_dict = example_dict.get("doc_annotation", {}) + token_dict = example_dict.get("token_annotation", {}) + # returned unmodified if: + # - doc is empty + # - words are not defined + # - links are defined (only character-based offsets, which is more a quirk + # of Example.to_dict than a technical constraint) + # - spans are defined + # - there are partial dependencies + if ( + len(example.reference) == 0 + or "ORTH" not in token_dict + or len(doc_dict.get("links", [])) > 0 + or len(example.reference.spans) > 0 + or ( + example.reference.has_annotation("DEP") + and not example.reference.has_annotation("DEP", require_complete=True) + ) + ): + return example + words = token_dict.get("ORTH", []) + length = len(words) + assert 0 <= position <= length + if example.reference.has_annotation("ENT_TYPE"): + # I-ENTITY if between B/I-ENTITY and I/L-ENTITY otherwise O + entity = "O" + if position > 1 and position < length: + ent_prev = doc_dict["entities"][position - 1] + ent_next = doc_dict["entities"][position] + if "-" in ent_prev and "-" in ent_next: + ent_iob_prev = ent_prev.split("-")[0] + ent_type_prev = ent_prev.split("-", 1)[1] + ent_iob_next = ent_next.split("-")[0] + ent_type_next = ent_next.split("-", 1)[1] + if ( + ent_iob_prev in ("B", "I") + and ent_iob_next in ("I", "L") + and ent_type_prev == ent_type_next + ): + entity = f"I-{ent_type_prev}" + doc_dict["entities"].insert(position, entity) + else: + del doc_dict["entities"] + token_dict["ORTH"].insert(position, whitespace) + token_dict["SPACY"].insert(position, False) + if example.reference.has_annotation("TAG"): + token_dict["TAG"].insert(position, "_SP") + else: + del token_dict["TAG"] + if example.reference.has_annotation("LEMMA"): + token_dict["LEMMA"].insert(position, whitespace) + else: + del token_dict["LEMMA"] + if example.reference.has_annotation("POS"): + token_dict["POS"].insert(position, "SPACE") + else: + del token_dict["POS"] + if example.reference.has_annotation("MORPH"): + token_dict["MORPH"].insert(position, "") + else: + del token_dict["MORPH"] + if example.reference.has_annotation("DEP", require_complete=True): + if position == 0: + token_dict["HEAD"].insert(position, 0) + else: + token_dict["HEAD"].insert(position, position - 1) + for i in range(len(token_dict["HEAD"])): + if token_dict["HEAD"][i] >= position: + token_dict["HEAD"][i] += 1 + token_dict["DEP"].insert(position, "dep") + else: + del token_dict["HEAD"] + del token_dict["DEP"] + if example.reference.has_annotation("SENT_START"): + token_dict["SENT_START"].insert(position, False) + else: + del token_dict["SENT_START"] + raw = construct_modified_raw_text(token_dict) + return Example.from_dict(nlp.make_doc(raw), example_dict) + + +def construct_modified_raw_text(token_dict): + """Construct modified raw text from words and spaces.""" raw = "" for orth, spacy in zip(token_dict["ORTH"], token_dict["SPACY"]): raw += orth if spacy: raw += " " - return raw, token_dict + return raw