Mirror of https://github.com/explosion/spaCy.git

Fix lowercase augmentation (#7336)

* Fix aborted/skipped augmentation for `spacy.orth_variants.v1` if lowercasing was enabled for an example
* Simplify `spacy.orth_variants.v1` for `Example` vs. `GoldParse`
* Preserve reference tokenization in `spacy.lower_case.v1`

parent cd70c3cb79
commit 3f3e8110dc
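The `spacy.lower_case.v1` fix is the easiest change to see in isolation: lowercasing a token like "CCC." can change how the raw text re-tokenizes, so the augmenter now builds the lowercased `ORTH` annotation from `example.reference` rather than from the re-tokenized doc. A minimal sketch of the fixed behavior, assuming spaCy v3's training API and a blank English pipeline (this script is an illustration, not part of the commit):

# Minimal sketch of the fixed spacy.lower_case.v1 behavior (assumes spaCy v3).
import spacy
from spacy.tokens import Doc
from spacy.training import Example
from spacy.training.augment import create_lower_casing_augmenter

nlp = spacy.blank("en")
augmenter = create_lower_casing_augmenter(level=1.0)

# "CCC." is a single gold token, but the tokenizer may split the lowercased
# text "ccc." into "ccc" + ".", so reference and predicted can diverge.
reference = Doc(nlp.vocab, words=["A", "B", "CCC."])
example = Example(nlp.make_doc(reference.text), reference)
augmented = next(augmenter(nlp, example))

# The reference side keeps the gold tokenization, only lowercased ...
assert [t.text for t in augmented.reference] == ["a", "b", "ccc."]
# ... while the predicted side uses the tokenizer's own segmentation
# (likely ["a", "b", "ccc", "."] for English).
print([t.text for t in augmented.predicted])

The diff below first updates the augmenter tests (spacy/tests/training/test_augmenters.py), then the augmenters themselves (spacy/training/augment.py).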
spacy/tests/training/test_augmenters.py
@@ -38,19 +38,59 @@ def doc(nlp):


 @pytest.mark.filterwarnings("ignore::UserWarning")
-def test_make_orth_variants(nlp, doc):
+def test_make_orth_variants(nlp):
     single = [
         {"tags": ["NFP"], "variants": ["…", "..."]},
         {"tags": [":"], "variants": ["-", "—", "–", "--", "---", "——"]},
     ]
+    # fmt: off
+    words = ["\n\n", "A", "\t", "B", "a", "b", "…", "...", "-", "—", "–", "--", "---", "——"]
+    tags = ["_SP", "NN", "\t", "NN", "NN", "NN", "NFP", "NFP", ":", ":", ":", ":", ":", ":"]
+    # fmt: on
+    spaces = [True] * len(words)
+    spaces[0] = False
+    spaces[2] = False
+    doc = Doc(nlp.vocab, words=words, spaces=spaces, tags=tags)
     augmenter = create_orth_variants_augmenter(
         level=0.2, lower=0.5, orth_variants={"single": single}
     )
-    with make_docbin([doc]) as output_file:
+    with make_docbin([doc] * 10) as output_file:
         reader = Corpus(output_file, augmenter=augmenter)
-        # Due to randomness, only test that it works without errors for now
+        # Due to randomness, only test that it works without errors
         list(reader(nlp))
+
+    # check that the following settings lowercase everything
+    augmenter = create_orth_variants_augmenter(
+        level=1.0, lower=1.0, orth_variants={"single": single}
+    )
+    with make_docbin([doc] * 10) as output_file:
+        reader = Corpus(output_file, augmenter=augmenter)
+        for example in reader(nlp):
+            for token in example.reference:
+                assert token.text == token.text.lower()
+
+    # check that lowercasing is applied without tags
+    doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words))
+    augmenter = create_orth_variants_augmenter(
+        level=1.0, lower=1.0, orth_variants={"single": single}
+    )
+    with make_docbin([doc] * 10) as output_file:
+        reader = Corpus(output_file, augmenter=augmenter)
+        for example in reader(nlp):
+            for ex_token, doc_token in zip(example.reference, doc):
+                assert ex_token.text == doc_token.text.lower()
+
+    # check that no lowercasing is applied with lower=0.0
+    doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words))
+    augmenter = create_orth_variants_augmenter(
+        level=1.0, lower=0.0, orth_variants={"single": single}
+    )
+    with make_docbin([doc] * 10) as output_file:
+        reader = Corpus(output_file, augmenter=augmenter)
+        for example in reader(nlp):
+            for ex_token, doc_token in zip(example.reference, doc):
+                assert ex_token.text == doc_token.text


 def test_lowercase_augmenter(nlp, doc):
     augmenter = create_lower_casing_augmenter(level=1.0)
@@ -66,6 +106,21 @@ def test_lowercase_augmenter(nlp, doc):
         assert ref_ent.text == orig_ent.text.lower()
     assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc]

+    # check that augmentation works when lowercasing leads to different
+    # predicted tokenization
+    words = ["A", "B", "CCC."]
+    doc = Doc(nlp.vocab, words=words)
+    with make_docbin([doc]) as output_file:
+        reader = Corpus(output_file, augmenter=augmenter)
+        corpus = list(reader(nlp))
+    eg = corpus[0]
+    assert eg.reference.text == doc.text.lower()
+    assert eg.predicted.text == doc.text.lower()
+    assert [t.text for t in eg.reference] == [t.lower() for t in words]
+    assert [t.text for t in eg.predicted] == [
+        t.text for t in nlp.make_doc(doc.text.lower())
+    ]
+

 @pytest.mark.filterwarnings("ignore::UserWarning")
 def test_custom_data_augmentation(nlp, doc):
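The heart of the `spacy.orth_variants.v1` simplification in the next file is that `make_orth_variants` no longer aborts or special-cases a missing raw text: it always returns a raw text, rebuilding it from the `ORTH` and `SPACY` token annotations. A standalone sketch of that reconstruction, with a hypothetical helper name:

# Rebuild a raw text from token texts (ORTH) and trailing-space flags (SPACY),
# mirroring the loop at the end of the updated make_orth_variants.
def rebuild_raw(orths, spacy_flags):
    raw = ""
    for orth, trailing_space in zip(orths, spacy_flags):
        raw += orth
        if trailing_space:
            raw += " "
    return raw

assert rebuild_raw(["a", "b", "ccc."], [True, True, False]) == "a b ccc."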
spacy/training/augment.py
@@ -1,12 +1,10 @@
 from typing import Callable, Iterator, Dict, List, Tuple, TYPE_CHECKING
 import random
 import itertools
-import copy
 from functools import partial
 from pydantic import BaseModel, StrictStr

 from ..util import registry
-from ..tokens import Doc
 from .example import Example

 if TYPE_CHECKING:
@@ -71,7 +69,7 @@ def lower_casing_augmenter(
     else:
         example_dict = example.to_dict()
         doc = nlp.make_doc(example.text.lower())
-        example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in doc]
+        example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference]
         yield example.from_dict(doc, example_dict)


@@ -88,24 +86,15 @@
     else:
         raw_text = example.text
         orig_dict = example.to_dict()
-        if not orig_dict["token_annotation"]:
-            yield example
-        else:
-            variant_text, variant_token_annot = make_orth_variants(
-                nlp,
-                raw_text,
-                orig_dict["token_annotation"],
-                orth_variants,
-                lower=raw_text is not None and random.random() < lower,
-            )
-            if variant_text:
-                doc = nlp.make_doc(variant_text)
-            else:
-                doc = Doc(nlp.vocab, words=variant_token_annot["ORTH"])
-                variant_token_annot["ORTH"] = [w.text for w in doc]
-                variant_token_annot["SPACY"] = [w.whitespace_ for w in doc]
-            orig_dict["token_annotation"] = variant_token_annot
-            yield example.from_dict(doc, orig_dict)
+        variant_text, variant_token_annot = make_orth_variants(
+            nlp,
+            raw_text,
+            orig_dict["token_annotation"],
+            orth_variants,
+            lower=raw_text is not None and random.random() < lower,
+        )
+        orig_dict["token_annotation"] = variant_token_annot
+        yield example.from_dict(nlp.make_doc(variant_text), orig_dict)


 def make_orth_variants(
@@ -116,88 +105,53 @@ def make_orth_variants(
     *,
     lower: bool = False,
 ) -> Tuple[str, Dict[str, List[str]]]:
-    orig_token_dict = copy.deepcopy(token_dict)
-    ndsv = orth_variants.get("single", [])
-    ndpv = orth_variants.get("paired", [])
     words = token_dict.get("ORTH", [])
     tags = token_dict.get("TAG", [])
-    # keep unmodified if words or tags are not defined
-    if words and tags:
-        if lower:
-            words = [w.lower() for w in words]
-        # single variants
-        punct_choices = [random.choice(x["variants"]) for x in ndsv]
-        for word_idx in range(len(words)):
-            for punct_idx in range(len(ndsv)):
-                if (
-                    tags[word_idx] in ndsv[punct_idx]["tags"]
-                    and words[word_idx] in ndsv[punct_idx]["variants"]
-                ):
-                    words[word_idx] = punct_choices[punct_idx]
-        # paired variants
-        punct_choices = [random.choice(x["variants"]) for x in ndpv]
-        for word_idx in range(len(words)):
-            for punct_idx in range(len(ndpv)):
-                if tags[word_idx] in ndpv[punct_idx]["tags"] and words[
-                    word_idx
-                ] in itertools.chain.from_iterable(ndpv[punct_idx]["variants"]):
-                    # backup option: random left vs. right from pair
-                    pair_idx = random.choice([0, 1])
-                    # best option: rely on paired POS tags like `` / ''
-                    if len(ndpv[punct_idx]["tags"]) == 2:
-                        pair_idx = ndpv[punct_idx]["tags"].index(tags[word_idx])
-                    # next best option: rely on position in variants
-                    # (may not be unambiguous, so order of variants matters)
-                    else:
-                        for pair in ndpv[punct_idx]["variants"]:
-                            if words[word_idx] in pair:
-                                pair_idx = pair.index(words[word_idx])
-                    words[word_idx] = punct_choices[punct_idx][pair_idx]
+    # keep unmodified if words are not defined
+    if not words:
+        return raw, token_dict
+    if lower:
+        words = [w.lower() for w in words]
+        raw = raw.lower()
+    # if no tags, only lowercase
+    if not tags:
+        token_dict["ORTH"] = words
+        return raw, token_dict
-    token_dict["ORTH"] = words
-    token_dict["TAG"] = tags
-    # modify raw
-    if raw is not None:
-        variants = []
-        for single_variants in ndsv:
-            variants.extend(single_variants["variants"])
-        for paired_variants in ndpv:
-            variants.extend(
-                list(itertools.chain.from_iterable(paired_variants["variants"]))
-            )
-        # store variants in reverse length order to be able to prioritize
-        # longer matches (e.g., "---" before "--")
-        variants = sorted(variants, key=lambda x: len(x))
-        variants.reverse()
-        variant_raw = ""
-        raw_idx = 0
-        # add initial whitespace
-        while raw_idx < len(raw) and raw[raw_idx].isspace():
-            variant_raw += raw[raw_idx]
-            raw_idx += 1
-        for word in words:
-            match_found = False
-            # skip whitespace words
-            if word.isspace():
-                match_found = True
-            # add identical word
-            elif word not in variants and raw[raw_idx:].startswith(word):
-                variant_raw += word
-                raw_idx += len(word)
-                match_found = True
-            # add variant word
-            else:
-                for variant in variants:
-                    if not match_found and raw[raw_idx:].startswith(variant):
-                        raw_idx += len(variant)
-                        variant_raw += word
-                        match_found = True
-            # something went wrong, abort
-            # (add a warning message?)
-            if not match_found:
-                return raw, orig_token_dict
-            # add following whitespace
-            while raw_idx < len(raw) and raw[raw_idx].isspace():
-                variant_raw += raw[raw_idx]
-                raw_idx += 1
-        raw = variant_raw
-    return raw, token_dict
+    # single variants
+    ndsv = orth_variants.get("single", [])
+    punct_choices = [random.choice(x["variants"]) for x in ndsv]
+    for word_idx in range(len(words)):
+        for punct_idx in range(len(ndsv)):
+            if (
+                tags[word_idx] in ndsv[punct_idx]["tags"]
+                and words[word_idx] in ndsv[punct_idx]["variants"]
+            ):
+                words[word_idx] = punct_choices[punct_idx]
+    # paired variants
+    ndpv = orth_variants.get("paired", [])
+    punct_choices = [random.choice(x["variants"]) for x in ndpv]
+    for word_idx in range(len(words)):
+        for punct_idx in range(len(ndpv)):
+            if tags[word_idx] in ndpv[punct_idx]["tags"] and words[
+                word_idx
+            ] in itertools.chain.from_iterable(ndpv[punct_idx]["variants"]):
+                # backup option: random left vs. right from pair
+                pair_idx = random.choice([0, 1])
+                # best option: rely on paired POS tags like `` / ''
+                if len(ndpv[punct_idx]["tags"]) == 2:
+                    pair_idx = ndpv[punct_idx]["tags"].index(tags[word_idx])
+                # next best option: rely on position in variants
+                # (may not be unambiguous, so order of variants matters)
+                else:
+                    for pair in ndpv[punct_idx]["variants"]:
+                        if words[word_idx] in pair:
+                            pair_idx = pair.index(words[word_idx])
+                words[word_idx] = punct_choices[punct_idx][pair_idx]
+    token_dict["ORTH"] = words
+    # construct modified raw text from words and spaces
+    raw = ""
+    for orth, spacy in zip(token_dict["ORTH"], token_dict["SPACY"]):
+        raw += orth
+        if spacy:
+            raw += " "
+    return raw, token_dict
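For context on how these augmenters are consumed: a `Corpus` applies its augmenter to each example as the corpus is read, which is how the tests above exercise the fixes. A sketch with a hypothetical corpus path (training configs expose the same factory as `spacy.orth_variants.v1`):

# Sketch: attaching the orth-variants augmenter to a training corpus.
import spacy
from spacy.training import Corpus
from spacy.training.augment import create_orth_variants_augmenter

orth_variants = {
    "single": [{"tags": ["NFP"], "variants": ["…", "..."]}],
    "paired": [{"tags": ["``", "''"], "variants": [["'", "'"], ["‘", "’"]]}],
}
# level: share of examples to augment; lower: chance an augmented example
# is also lowercased (the path this commit fixes).
augmenter = create_orth_variants_augmenter(
    level=0.2, lower=0.5, orth_variants=orth_variants
)
# "corpus/train.spacy" is a hypothetical path to a serialized DocBin.
corpus = Corpus("corpus/train.spacy", augmenter=augmenter)
examples = list(corpus(spacy.blank("en")))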