Fix data augmentation

This commit is contained in:
Matthew Honnibal 2020-09-29 23:40:54 +02:00
parent 14c4da547f
commit f52249fe2e

View File

@@ -4,6 +4,7 @@ import itertools
import copy import copy
from functools import partial from functools import partial
from ..util import registry from ..util import registry
from ..tokens import Doc
@registry.augmenters("spacy.dont_augment.v1") @registry.augmenters("spacy.dont_augment.v1")
@@ -38,10 +39,12 @@ def orth_variants_augmenter(nlp, example, *, level: float = 0.0, lower: float =
orig_dict["token_annotation"], orig_dict["token_annotation"],
lower=raw_text is not None and random.random() < lower, lower=raw_text is not None and random.random() < lower,
) )
if variant_text is None: if variant_text:
doc = Doc(nlp.vocab, words=variant_token_annot["words"])
else:
doc = nlp.make_doc(variant_text) doc = nlp.make_doc(variant_text)
else:
doc = Doc(nlp.vocab, words=variant_token_annot["ORTH"])
variant_token_annot["ORTH"] = [w.text for w in doc]
variant_token_annot["SPACY"] = [w.whitespace_ for w in doc]
orig_dict["token_annotation"] = variant_token_annot orig_dict["token_annotation"] = variant_token_annot
yield example.from_dict(doc, orig_dict) yield example.from_dict(doc, orig_dict)