spaCy/spacy/gold/augment.py

112 lines
4.5 KiB
Python
Raw Normal View History

2020-06-06 15:19:06 +03:00
import random
import itertools
2020-06-19 15:05:35 +03:00
def make_orth_variants_example(nlp, example, orth_variant_level=0.0):  # TODO: naming
    """Build a new example whose text and tokens use orthographic variants.

    The example is serialised with ``example.to_dict()``, its
    ``"token_annotation"`` entry and text are passed through
    ``make_orth_variants``, and a new example is created from a doc made
    with ``nlp.make_doc`` on the (possibly modified) text.

    nlp: object providing ``make_doc`` and the orth-variant tables read by
        ``make_orth_variants``.
    example: object providing ``text``, ``to_dict()`` and ``from_dict()``.
    orth_variant_level: forwarded unchanged to ``make_orth_variants``.
    RETURNS: the result of ``example.from_dict(doc, annotations)``.
    """
    text = example.text
    annotations = example.to_dict()
    new_text, new_token_annot = make_orth_variants(
        nlp,
        text,
        annotations["token_annotation"],
        orth_variant_level,
    )
    variant_doc = nlp.make_doc(new_text)
    # Swap in the (possibly) rewritten token annotation before rebuilding.
    annotations["token_annotation"] = new_token_annot
    return example.from_dict(variant_doc, annotations)
2020-06-17 11:46:29 +03:00
def make_orth_variants(nlp, raw_text, orig_token_dict, orth_variant_level=0.0):
    """Randomly replace punctuation-like tokens with orthographic variants.

    With probability ``orth_variant_level`` the words in ``orig_token_dict``
    (and the matching spans of ``raw_text``) are rewritten using the variant
    tables ``nlp.Defaults.single_orth_variants`` and
    ``nlp.Defaults.paired_orth_variants``. Independently, with probability
    0.5, the words and the raw text are lower-cased.

    The input dict and its word list are never modified: a shallow copy of
    the dict with a new ``"words"`` list is returned instead. This keeps the
    "return everything unchanged" paths truly unchanged.

    nlp: object exposing ``Defaults.single_orth_variants`` (entries with
        ``"tags"`` and ``"variants"`` keys, variants being strings) and
        ``Defaults.paired_orth_variants`` (same keys, variants being pairs).
    raw_text: the raw text, or None if there is no raw text.
    orig_token_dict: dict-like with optional ``"words"`` and ``"tags"`` lists.
    orth_variant_level: probability in [0, 1] of attempting the rewrite.
    RETURNS: ``(raw, token_dict)``; the inputs themselves are returned if no
        rewrite is attempted, if words or tags are missing, or if the raw
        text cannot be aligned with the words.
    """
    if random.random() >= orth_variant_level:
        return raw_text, orig_token_dict
    if not orig_token_dict:
        return raw_text, orig_token_dict
    # Drawn before any further checks so the sequence of RNG calls is stable.
    lower = random.random() >= 0.5
    single_variants = nlp.Defaults.single_orth_variants
    paired_variants = nlp.Defaults.paired_orth_variants
    words = orig_token_dict.get("words", [])
    tags = orig_token_dict.get("tags", [])
    # Keep everything unmodified if words or tags are not defined; rebuilding
    # the raw text from an empty word list would otherwise drop the text.
    if not words or not tags:
        return raw_text, orig_token_dict
    raw = raw_text
    if lower:
        words = [word.lower() for word in words]
        if raw is not None:
            raw = raw.lower()
    words = _apply_orth_variants(words, tags, single_variants, paired_variants)
    if raw is not None:
        variants = []
        for entry in single_variants:
            variants.extend(entry["variants"])
        for entry in paired_variants:
            variants.extend(itertools.chain.from_iterable(entry["variants"]))
        raw = _realign_raw_with_words(raw, words, variants)
        if raw is None:
            # something went wrong, abort
            # (add a warning message?)
            return raw_text, orig_token_dict
    token_dict = dict(orig_token_dict)
    token_dict["words"] = words
    token_dict["tags"] = tags
    return raw, token_dict


def _apply_orth_variants(words, tags, single_variants, paired_variants):
    """Return a copy of ``words`` with variant tokens replaced by random picks."""
    words = list(words)
    # single variants: one random replacement per table entry
    single_choices = [random.choice(entry["variants"]) for entry in single_variants]
    for idx in range(len(words)):
        for entry, choice in zip(single_variants, single_choices):
            if tags[idx] in entry["tags"] and words[idx] in entry["variants"]:
                words[idx] = choice
    # paired variants: one random (left, right) pair per table entry
    paired_choices = [random.choice(entry["variants"]) for entry in paired_variants]
    for idx in range(len(words)):
        for entry, choice in zip(paired_variants, paired_choices):
            if tags[idx] not in entry["tags"]:
                continue
            if words[idx] not in itertools.chain.from_iterable(entry["variants"]):
                continue
            # backup option: random left vs. right from pair
            pair_idx = random.choice([0, 1])
            if len(entry["tags"]) == 2:
                # best option: rely on paired POS tags like `` / ''
                pair_idx = entry["tags"].index(tags[idx])
            else:
                # next best option: rely on position in variants
                # (may not be unambiguous, so order of variants matters)
                for pair in entry["variants"]:
                    if words[idx] in pair:
                        pair_idx = pair.index(words[idx])
            words[idx] = choice[pair_idx]
    return words


def _realign_raw_with_words(raw, words, variants):
    """Rebuild ``raw`` so its tokens read as ``words``; None if they don't align."""
    # Try longer variants first (e.g., "---" before "--"). Order among
    # equal-length variants is irrelevant: two distinct strings of the same
    # length cannot both be a prefix at the same position.
    by_length = sorted(variants, key=len, reverse=True)
    known_variants = set(variants)
    end = len(raw)
    pieces = []
    pos = 0
    # add initial whitespace
    while pos < end and raw[pos].isspace():
        pos += 1
    pieces.append(raw[:pos])
    for word in words:
        if word.isspace():
            # whitespace words contribute nothing themselves; the raw
            # whitespace is copied below
            pass
        elif word not in known_variants and raw.startswith(word, pos):
            # identical word
            pieces.append(word)
            pos += len(word)
        else:
            # variant word: consume whichever variant is in the raw text and
            # emit the (possibly different) variant chosen for the token
            for variant in by_length:
                if raw.startswith(variant, pos):
                    pos += len(variant)
                    pieces.append(word)
                    break
            else:
                return None
        # add following whitespace
        start = pos
        while pos < end and raw[pos].isspace():
            pos += 1
        pieces.append(raw[start:pos])
    return "".join(pieces)