diff --git a/spacy/gold.pyx b/spacy/gold.pyx index 6b07592cc..57b5dc039 100644 --- a/spacy/gold.pyx +++ b/spacy/gold.pyx @@ -199,14 +199,16 @@ class GoldCorpus(object): return n def train_docs(self, nlp, gold_preproc=False, - projectivize=False, max_length=None): + projectivize=False, max_length=None, + noise_level=0.0): train_tuples = self.train_tuples if projectivize: train_tuples = nonproj.preprocess_training_data( self.train_tuples) random.shuffle(train_tuples) gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc, - max_length=max_length) + max_length=max_length, + noise_level=noise_level) yield from gold_docs def dev_docs(self, nlp, gold_preproc=False): @@ -215,7 +217,8 @@ class GoldCorpus(object): yield from gold_docs @classmethod - def iter_gold_docs(cls, nlp, tuples, gold_preproc, max_length=None): + def iter_gold_docs(cls, nlp, tuples, gold_preproc, max_length=None, + noise_level=0.0): for raw_text, paragraph_tuples in tuples: if gold_preproc: raw_text = None @@ -223,18 +226,20 @@ class GoldCorpus(object): paragraph_tuples = merge_sents(paragraph_tuples) docs = cls._make_docs(nlp, raw_text, paragraph_tuples, - gold_preproc) + gold_preproc, noise_level=noise_level) golds = cls._make_golds(docs, paragraph_tuples) for doc, gold in zip(docs, golds): if (not max_length) or len(doc) < max_length: yield doc, gold @classmethod - def _make_docs(cls, nlp, raw_text, paragraph_tuples, gold_preproc): + def _make_docs(cls, nlp, raw_text, paragraph_tuples, gold_preproc, + noise_level=0.0): if raw_text is not None: + raw_text = add_noise(raw_text, noise_level) return [nlp.make_doc(raw_text)] else: - return [Doc(nlp.vocab, words=sent_tuples[1]) + return [Doc(nlp.vocab, words=add_noise(sent_tuples[1], noise_level)) for (sent_tuples, brackets) in paragraph_tuples] @classmethod @@ -266,6 +271,30 @@ class GoldCorpus(object): return locs +def add_noise(orig, noise_level): + if random.random() >= noise_level: + return orig + elif type(orig) == list: + corrupted = [_corrupt(word, noise_level) for word in orig] + corrupted = [w for w in corrupted if w] + return corrupted + else: + return ''.join(_corrupt(c, noise_level) for c in orig) + + +def _corrupt(c, noise_level): + if random.random() >= noise_level: + return c + elif c == ' ': + return '\n' + elif c == '\n': + return ' ' + elif c in ['.', "'", "!", "?"]: + return '' + else: + return c.lower() + + def read_json_file(loc, docs_filter=None, limit=None): loc = ensure_path(loc) if loc.is_dir():