Update CoNLL script. Don't preset SBD. Set batch size to 8, avoid writing twice

This commit is contained in:
Matthew Honnibal 2018-02-22 21:35:50 +01:00
parent a26e399f84
commit 23236340f4

View File

@ -191,12 +191,7 @@ def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
docs, golds = read_data(nlp, conllu_file, text_file, docs, golds = read_data(nlp, conllu_file, text_file,
oracle_segments=oracle_segments) oracle_segments=oracle_segments)
if joint_sbd: if joint_sbd:
sbd = nlp.create_pipe('sentencizer') pass
for doc in docs:
doc = sbd(doc)
for sent in doc.sents:
sent[0].is_sent_start = True
#docs = (prevent_bad_sentences(doc) for doc in docs)
else: else:
sbd = nlp.create_pipe('sentencizer') sbd = nlp.create_pipe('sentencizer')
for doc in docs: for doc in docs:
@ -276,8 +271,8 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
print("Begin training") print("Begin training")
# Batch size starts at 1 and grows, so that we make updates quickly # Batch size starts at 1 and grows, so that we make updates quickly
# at the beginning of training. # at the beginning of training.
batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 2), batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 8),
spacy.util.env_opt('batch_to', 2), spacy.util.env_opt('batch_to', 8),
spacy.util.env_opt('batch_compound', 1.001)) spacy.util.env_opt('batch_compound', 1.001))
for i in range(30): for i in range(30):
docs = refresh_docs(docs) docs = refresh_docs(docs)
@ -288,7 +283,6 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
if not batch: if not batch:
continue continue
batch_docs, batch_gold = zip(*batch) batch_docs, batch_gold = zip(*batch)
batch_docs = [prevent_bad_sentences(doc) for doc in batch_docs]
nlp.update(batch_docs, batch_gold, sgd=optimizer, nlp.update(batch_docs, batch_gold, sgd=optimizer,
drop=0.2, losses=losses) drop=0.2, losses=losses)
@ -303,8 +297,6 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc, dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
oracle_segments=False, joint_sbd=False) oracle_segments=False, joint_sbd=False)
print_progress(i, losses, scorer) print_progress(i, losses, scorer)
with open(output_loc, 'w') as file_:
print_conllu(dev_docs, file_)
if __name__ == '__main__': if __name__ == '__main__':