Always shuffle gold data, and support length cap

This commit is contained in:
Matthew Honnibal 2017-05-26 11:30:52 -05:00
parent d65f99a720
commit daac3e3573

View File

@ -198,15 +198,15 @@ class GoldCorpus(object):
n += 1 n += 1
return n return n
def train_docs(self, nlp, shuffle=0, gold_preproc=False, def train_docs(self, nlp, gold_preproc=False,
projectivize=False): projectivize=False, max_length=None):
train_tuples = self.train_tuples train_tuples = self.train_tuples
if projectivize: if projectivize:
train_tuples = nonproj.preprocess_training_data( train_tuples = nonproj.preprocess_training_data(
self.train_tuples) self.train_tuples)
if shuffle: random.shuffle(train_tuples)
random.shuffle(train_tuples) gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc,
gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc) max_length=max_length)
yield from gold_docs yield from gold_docs
def dev_docs(self, nlp, gold_preproc=False): def dev_docs(self, nlp, gold_preproc=False):
@ -215,7 +215,7 @@ class GoldCorpus(object):
yield from gold_docs yield from gold_docs
@classmethod @classmethod
def iter_gold_docs(cls, nlp, tuples, gold_preproc): def iter_gold_docs(cls, nlp, tuples, gold_preproc, max_length=None):
for raw_text, paragraph_tuples in tuples: for raw_text, paragraph_tuples in tuples:
if gold_preproc: if gold_preproc:
raw_text = None raw_text = None
@ -226,7 +226,8 @@ class GoldCorpus(object):
gold_preproc) gold_preproc)
golds = cls._make_golds(docs, paragraph_tuples) golds = cls._make_golds(docs, paragraph_tuples)
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
yield doc, gold if not max_length or len(doc) < max_length:
yield doc, gold
@classmethod @classmethod
def _make_docs(cls, nlp, raw_text, paragraph_tuples, gold_preproc): def _make_docs(cls, nlp, raw_text, paragraph_tuples, gold_preproc):