Update conllu training script

This commit is contained in:
Matthew Honnibal 2018-02-25 13:12:39 +01:00
parent e09070eca7
commit bdb0174571

View File

@ -352,14 +352,15 @@ class TreebankPaths(object):
config=("Path to json formatted config file", "positional", None, Config.load), config=("Path to json formatted config file", "positional", None, Config.load),
corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc", corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc",
"positional", None, str), "positional", None, str),
parses=("Path to write the development parses", "positional", None, Path) parses_loc=("Path to write the development parses", "positional", None, Path),
limit=("Size limit", "option", "n", int)
) )
def main(ud_dir, corpus, config, parses='/tmp/dev.conllu'): def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10):
paths = TreebankPaths(ud_dir, corpus) paths = TreebankPaths(ud_dir, corpus)
nlp = load_nlp(paths.lang, config) nlp = load_nlp(paths.lang, config)
docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(), docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
config) limit=limit)
optimizer = initialize_pipeline(nlp, docs, golds, config) optimizer = initialize_pipeline(nlp, docs, golds, config)
n_train_words = sum(len(doc) for doc in docs) n_train_words = sum(len(doc) for doc in docs)
@ -379,7 +380,7 @@ def main(ud_dir, corpus, config, parses='/tmp/dev.conllu'):
with nlp.use_params(optimizer.averages): with nlp.use_params(optimizer.averages):
dev_docs, scorer = parse_dev_data(nlp, paths.dev.text, paths.dev.conllu) dev_docs, scorer = parse_dev_data(nlp, paths.dev.text, paths.dev.conllu)
print_progress(i, losses, scorer) print_progress(i, losses, scorer)
with open(output_loc, 'w') as file_: with open(parses_loc, 'w') as file_:
print_conllu(dev_docs, file_) print_conllu(dev_docs, file_)