mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Fix augmenter
This commit is contained in:
parent
549758f67d
commit
d2b9aafb8c
|
@@ -120,8 +120,8 @@ def make_orth_variants(
|
|||
ndsv = orth_variants.get("single", [])
|
||||
ndpv = orth_variants.get("paired", [])
|
||||
logger.debug(f"Data augmentation: {len(ndsv)} single / {len(ndpv)} paired variants")
|
||||
words = token_dict.get("words", [])
|
||||
tags = token_dict.get("tags", [])
|
||||
words = token_dict.get("ORTH", [])
|
||||
tags = token_dict.get("TAG", [])
|
||||
# keep unmodified if words or tags are not defined
|
||||
if words and tags:
|
||||
if lower:
|
||||
|
@@ -131,7 +131,7 @@ def make_orth_variants(
|
|||
for word_idx in range(len(words)):
|
||||
for punct_idx in range(len(ndsv)):
|
||||
if (
|
||||
tags[word_idx] in ndsv[punct_idx]["tags"]
|
||||
tags[word_idx] in ndsv[punct_idx]["TAG"]
|
||||
and words[word_idx] in ndsv[punct_idx]["variants"]
|
||||
):
|
||||
words[word_idx] = punct_choices[punct_idx]
|
||||
|
@@ -139,14 +139,14 @@ def make_orth_variants(
|
|||
punct_choices = [random.choice(x["variants"]) for x in ndpv]
|
||||
for word_idx in range(len(words)):
|
||||
for punct_idx in range(len(ndpv)):
|
||||
if tags[word_idx] in ndpv[punct_idx]["tags"] and words[
|
||||
if tags[word_idx] in ndpv[punct_idx]["TAG"] and words[
|
||||
word_idx
|
||||
] in itertools.chain.from_iterable(ndpv[punct_idx]["variants"]):
|
||||
# backup option: random left vs. right from pair
|
||||
pair_idx = random.choice([0, 1])
|
||||
# best option: rely on paired POS tags like `` / ''
|
||||
if len(ndpv[punct_idx]["tags"]) == 2:
|
||||
pair_idx = ndpv[punct_idx]["tags"].index(tags[word_idx])
|
||||
if len(ndpv[punct_idx]["TAG"]) == 2:
|
||||
pair_idx = ndpv[punct_idx]["TAG"].index(tags[word_idx])
|
||||
# next best option: rely on position in variants
|
||||
# (may not be unambiguous, so order of variants matters)
|
||||
else:
|
||||
|
@@ -154,8 +154,8 @@ def make_orth_variants(
|
|||
if words[word_idx] in pair:
|
||||
pair_idx = pair.index(words[word_idx])
|
||||
words[word_idx] = punct_choices[punct_idx][pair_idx]
|
||||
token_dict["words"] = words
|
||||
token_dict["tags"] = tags
|
||||
token_dict["ORTH"] = words
|
||||
token_dict["TAG"] = tags
|
||||
# modify raw
|
||||
if raw is not None:
|
||||
variants = []
|
||||
|
|
Loading…
Reference in New Issue
Block a user