Shuffle data after each epoch. Improve script

Matthew Honnibal 2018-02-25 13:35:32 +01:00
parent bdb0174571
commit 551c93fe01


@@ -30,6 +30,7 @@ random.seed(0)
 numpy.random.seed(0)
 
 def minibatch_by_words(items, size=5000):
+    random.shuffle(items)
     if isinstance(size, int):
         size_ = itertools.repeat(size)
     else:
@@ -39,10 +40,17 @@ def minibatch_by_words(items, size=5000):
         batch_size = next(size_)
         batch = []
         while batch_size >= 0:
-            doc, gold = next(items)
+            try:
+                doc, gold = next(items)
+            except StopIteration:
+                yield batch
+                return
             batch_size -= len(doc)
             batch.append((doc, gold))
-        yield batch
+        if batch:
+            yield batch
+        else:
+            break
 
 ################
 # Data reading #
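
For reference, here is the whole generator as it reads after this commit, followed by a toy driver. The driver and its stand-in data are illustrative only: any object with len() works in place of a spaCy Doc.

import itertools
import random

def minibatch_by_words(items, size=5000):
    # Shuffle in place, so every epoch walks the corpus in a new order.
    random.shuffle(items)
    if isinstance(size, int):
        # A fixed int becomes an endless stream of identical word budgets;
        # an iterator can be passed instead for a batch-size schedule.
        size_ = itertools.repeat(size)
    else:
        size_ = size
    items = iter(items)
    while True:
        batch_size = next(size_)
        batch = []
        while batch_size >= 0:
            try:
                doc, gold = next(items)
            except StopIteration:
                # End of data: emit whatever is left (possibly empty).
                yield batch
                return
            batch_size -= len(doc)
            batch.append((doc, gold))
        if batch:
            yield batch
        else:
            break

# Toy driver: batches are capped by total word count, not item count.
pairs = [(['tok'] * n, None) for n in (3, 8, 2, 5, 7)]
for batch in minibatch_by_words(pairs, size=10):
    print([len(doc) for doc, _ in batch])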
@@ -146,10 +154,6 @@ def _make_gold(nlp, text, sent_annots):
     doc = nlp.make_doc(text)
     flat.pop('spaces')
     gold = GoldParse(doc, **flat)
-    #for annot in gold.orig_annot:
-    #    print(annot)
-    #for i in range(len(doc)):
-    #    print(doc[i].text, gold.words[i], gold.labels[i], gold.heads[i])
     return doc, gold
 
 #############################
@@ -168,12 +172,6 @@ def golds_to_gold_tuples(docs, golds):
     return tuples
 
 
-def refresh_docs(docs):
-    vocab = docs[0].vocab
-    return [Doc(vocab, words=[t.text for t in doc],
-                spaces=[t.whitespace_ for t in doc])
-            for doc in docs]
-
 ##############
 # Evaluation #
 ##############
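
Note the behavioural difference this removal implies: refresh_docs copied tokens one by one, so the original tokenization survived each epoch, while the replacement in main() below re-runs the tokenizer on the raw text. A side-by-side sketch; the second function's name is hypothetical, since main() now inlines it:

from spacy.tokens import Doc

def refresh_docs(docs):
    # Removed helper: rebuild each Doc token-by-token, preserving
    # the existing tokenization and whitespace exactly.
    vocab = docs[0].vocab
    return [Doc(vocab, words=[t.text for t in doc],
                spaces=[t.whitespace_ for t in doc])
            for doc in docs]

def retokenize_docs(nlp, docs):
    # What main() does now: re-run the tokenizer on the raw text,
    # discarding any annotations set during the previous epoch.
    return [nlp.make_doc(doc.text) for doc in docs]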
@@ -357,23 +355,22 @@ class TreebankPaths(object):
 )
 def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10):
     paths = TreebankPaths(ud_dir, corpus)
     print("Train and evaluate", corpus, "using lang", paths.lang)
     nlp = load_nlp(paths.lang, config)
     docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
-                            limit=limit)
+                            max_doc_length=config.max_doc_length, limit=limit)
     optimizer = initialize_pipeline(nlp, docs, golds, config)
-    n_train_words = sum(len(doc) for doc in docs)
-    print("Begin training (%d words)" % n_train_words)
 
     for i in range(config.nr_epoch):
-        docs = refresh_docs(docs)
+        docs = [nlp.make_doc(doc.text) for doc in docs]
         batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
         losses = {}
-        for batch in tqdm.tqdm(batches, total=n_train_words//config.batch_size):
-            if not batch:
-                continue
-            batch_docs, batch_gold = zip(*batch)
-            nlp.update(batch_docs, batch_gold, sgd=optimizer,
-                       drop=config.dropout, losses=losses)
+        n_train_words = sum(len(doc) for doc in docs)
+        with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
+            for batch in batches:
+                batch_docs, batch_gold = zip(*batch)
+                pbar.update(sum(len(doc) for doc in batch_docs))
+                nlp.update(batch_docs, batch_gold, sgd=optimizer,
+                           drop=config.dropout, losses=losses)
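
The old progress bar counted batches, but minibatch_by_words packs a variable number of docs per batch, so total=n_train_words//config.batch_size was only an estimate. The new loop drives a manual bar by words instead. A self-contained sketch of that pattern with stand-in data; time.sleep stands in for the real nlp.update call:

import time
import tqdm

# Stand-in corpus: each "doc" is a token list, as in the script.
docs = [['tok'] * n for n in (40, 25, 60, 15, 30)]
batches = [docs[i:i + 2] for i in range(0, len(docs), 2)]

n_train_words = sum(len(doc) for doc in docs)
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
    for batch in batches:
        # Advance by words processed, not by batch count, so the bar
        # stays accurate even when batch sizes vary.
        pbar.update(sum(len(doc) for doc in batch))
        time.sleep(0.2)  # stand-in for nlp.update(...)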