Always shuffle gold data, and support length cap

This commit is contained in:
Matthew Honnibal 2017-05-26 11:30:52 -05:00
parent d65f99a720
commit daac3e3573

View File

@ -198,15 +198,15 @@ class GoldCorpus(object):
n += 1
return n
def train_docs(self, nlp, shuffle=0, gold_preproc=False,
projectivize=False):
def train_docs(self, nlp, gold_preproc=False,
projectivize=False, max_length=None):
train_tuples = self.train_tuples
if projectivize:
train_tuples = nonproj.preprocess_training_data(
self.train_tuples)
if shuffle:
random.shuffle(train_tuples)
gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc)
gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc,
max_length=max_length)
yield from gold_docs
def dev_docs(self, nlp, gold_preproc=False):
@ -215,7 +215,7 @@ class GoldCorpus(object):
yield from gold_docs
@classmethod
def iter_gold_docs(cls, nlp, tuples, gold_preproc):
def iter_gold_docs(cls, nlp, tuples, gold_preproc, max_length=None):
for raw_text, paragraph_tuples in tuples:
if gold_preproc:
raw_text = None
@ -226,6 +226,7 @@ class GoldCorpus(object):
gold_preproc)
golds = cls._make_golds(docs, paragraph_tuples)
for doc, gold in zip(docs, golds):
if not max_length or len(doc) < max_length:
yield doc, gold
@classmethod