mirror of
https://github.com/explosion/spaCy.git
synced 2025-03-13 07:55:49 +03:00
Support max_length in Corpus
This commit is contained in:
parent
d5212f7ba8
commit
a68d0e63f0
|
@ -301,7 +301,8 @@ def create_train_batches(nlp, corpus, cfg):
|
||||||
train_examples = list(corpus.train_dataset(
|
train_examples = list(corpus.train_dataset(
|
||||||
nlp,
|
nlp,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
gold_preproc=cfg["gold_preproc"]
|
gold_preproc=cfg["gold_preproc"],
|
||||||
|
max_length=cfg["max_length"]
|
||||||
))
|
))
|
||||||
if len(train_examples) == 0:
|
if len(train_examples) == 0:
|
||||||
raise ValueError(Errors.E988)
|
raise ValueError(Errors.E988)
|
||||||
|
|
|
@ -43,24 +43,36 @@ class Corpus:
|
||||||
locs.append(path)
|
locs.append(path)
|
||||||
return locs
|
return locs
|
||||||
|
|
||||||
def make_examples(self, nlp, reference_docs):
|
def make_examples(self, nlp, reference_docs, max_length=0):
|
||||||
for reference in reference_docs:
|
for reference in reference_docs:
|
||||||
predicted = nlp.make_doc(reference.text)
|
if max_length >= 1 and len(reference) >= max_length:
|
||||||
yield Example(predicted, reference)
|
if reference.is_sentenced:
|
||||||
|
for ref_sent in reference.sents:
|
||||||
|
yield Example(
|
||||||
|
nlp.make_doc(ref_sent.text),
|
||||||
|
ref_sent.as_doc()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield Example(
|
||||||
|
nlp.make_doc(reference.text),
|
||||||
|
reference
|
||||||
|
)
|
||||||
|
|
||||||
def make_examples_gold_preproc(self, nlp, reference_docs):
|
def make_examples_gold_preproc(self, nlp, reference_docs):
|
||||||
for whole_reference in reference_docs:
|
for reference in reference_docs:
|
||||||
if whole_reference.is_sentenced:
|
if reference.is_sentenced:
|
||||||
references = [sent.as_doc() for sent in whole_reference.sents]
|
ref_sents = [sent.as_doc() for sent in reference.sents]
|
||||||
else:
|
else:
|
||||||
references = [whole_reference]
|
ref_sents = [reference]
|
||||||
for reference in references:
|
for ref_sent in ref_sents:
|
||||||
predicted = Doc(
|
yield Example(
|
||||||
nlp.vocab,
|
Doc(
|
||||||
words=[t.text for t in reference],
|
nlp.vocab,
|
||||||
spaces=[bool(t.whitespace_) for t in reference]
|
words=[w.text for w in ref_sent],
|
||||||
|
spaces=[bool(w.whitespace_) for w in ref_sent]
|
||||||
|
),
|
||||||
|
ref_sent
|
||||||
)
|
)
|
||||||
yield Example(predicted, reference)
|
|
||||||
|
|
||||||
def read_docbin(self, vocab, locs):
|
def read_docbin(self, vocab, locs):
|
||||||
""" Yield training examples as example dicts """
|
""" Yield training examples as example dicts """
|
||||||
|
@ -86,12 +98,13 @@ class Corpus:
|
||||||
i += 1
|
i += 1
|
||||||
return n
|
return n
|
||||||
|
|
||||||
def train_dataset(self, nlp, *, shuffle=True, gold_preproc=False, **kwargs):
|
def train_dataset(self, nlp, *, shuffle=True, gold_preproc=False,
|
||||||
|
max_length=0, **kwargs):
|
||||||
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.train_loc))
|
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.train_loc))
|
||||||
if gold_preproc:
|
if gold_preproc:
|
||||||
examples = self.make_examples_gold_preproc(nlp, ref_docs)
|
examples = self.make_examples_gold_preproc(nlp, ref_docs)
|
||||||
else:
|
else:
|
||||||
examples = self.make_examples(nlp, ref_docs)
|
examples = self.make_examples(nlp, ref_docs, max_length)
|
||||||
if shuffle:
|
if shuffle:
|
||||||
examples = list(examples)
|
examples = list(examples)
|
||||||
random.shuffle(examples)
|
random.shuffle(examples)
|
||||||
|
@ -102,5 +115,5 @@ class Corpus:
|
||||||
if gold_preproc:
|
if gold_preproc:
|
||||||
examples = self.make_examples_gold_preproc(nlp, ref_docs)
|
examples = self.make_examples_gold_preproc(nlp, ref_docs)
|
||||||
else:
|
else:
|
||||||
examples = self.make_examples(nlp, ref_docs)
|
examples = self.make_examples(nlp, ref_docs, max_length=0)
|
||||||
yield from examples
|
yield from examples
|
||||||
|
|
Loading…
Reference in New Issue
Block a user