Support max_length in Corpus

This commit is contained in:
Matthew Honnibal 2020-06-23 22:57:40 +02:00
parent d5212f7ba8
commit a68d0e63f0
2 changed files with 31 additions and 17 deletions

View File

@ -301,7 +301,8 @@ def create_train_batches(nlp, corpus, cfg):
train_examples = list(corpus.train_dataset( train_examples = list(corpus.train_dataset(
nlp, nlp,
shuffle=True, shuffle=True,
gold_preproc=cfg["gold_preproc"] gold_preproc=cfg["gold_preproc"],
max_length=cfg["max_length"]
)) ))
if len(train_examples) == 0: if len(train_examples) == 0:
raise ValueError(Errors.E988) raise ValueError(Errors.E988)

View File

@ -43,24 +43,36 @@ class Corpus:
locs.append(path) locs.append(path)
return locs return locs
def make_examples(self, nlp, reference_docs): def make_examples(self, nlp, reference_docs, max_length=0):
for reference in reference_docs: for reference in reference_docs:
predicted = nlp.make_doc(reference.text) if max_length >= 1 and len(reference) >= max_length:
yield Example(predicted, reference) if reference.is_sentenced:
for ref_sent in reference.sents:
yield Example(
nlp.make_doc(ref_sent.text),
ref_sent.as_doc()
)
else:
yield Example(
nlp.make_doc(reference.text),
reference
)
def make_examples_gold_preproc(self, nlp, reference_docs): def make_examples_gold_preproc(self, nlp, reference_docs):
for whole_reference in reference_docs: for reference in reference_docs:
if whole_reference.is_sentenced: if reference.is_sentenced:
references = [sent.as_doc() for sent in whole_reference.sents] ref_sents = [sent.as_doc() for sent in reference.sents]
else: else:
references = [whole_reference] ref_sents = [reference]
for reference in references: for ref_sent in ref_sents:
predicted = Doc( yield Example(
nlp.vocab, Doc(
words=[t.text for t in reference], nlp.vocab,
spaces=[bool(t.whitespace_) for t in reference] words=[w.text for w in ref_sent],
spaces=[bool(w.whitespace_) for w in ref_sent]
),
ref_sent
) )
yield Example(predicted, reference)
def read_docbin(self, vocab, locs): def read_docbin(self, vocab, locs):
""" Yield training examples as example dicts """ """ Yield training examples as example dicts """
@ -86,12 +98,13 @@ class Corpus:
i += 1 i += 1
return n return n
def train_dataset(self, nlp, *, shuffle=True, gold_preproc=False, **kwargs): def train_dataset(self, nlp, *, shuffle=True, gold_preproc=False,
max_length=0, **kwargs):
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.train_loc)) ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.train_loc))
if gold_preproc: if gold_preproc:
examples = self.make_examples_gold_preproc(nlp, ref_docs) examples = self.make_examples_gold_preproc(nlp, ref_docs)
else: else:
examples = self.make_examples(nlp, ref_docs) examples = self.make_examples(nlp, ref_docs, max_length)
if shuffle: if shuffle:
examples = list(examples) examples = list(examples)
random.shuffle(examples) random.shuffle(examples)
@ -102,5 +115,5 @@ class Corpus:
if gold_preproc: if gold_preproc:
examples = self.make_examples_gold_preproc(nlp, ref_docs) examples = self.make_examples_gold_preproc(nlp, ref_docs)
else: else:
examples = self.make_examples(nlp, ref_docs) examples = self.make_examples(nlp, ref_docs, max_length=0)
yield from examples yield from examples