diff --git a/spacy/cli/ud_train.py b/spacy/cli/ud_train.py
index 86fdabca4..6e00f3193 100644
--- a/spacy/cli/ud_train.py
+++ b/spacy/cli/ud_train.py
@@ -161,9 +161,19 @@ def golds_to_gold_tuples(docs, golds):
 ##############
 
 def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
-    with text_loc.open('r', encoding='utf8') as text_file:
-        texts = split_text(text_file.read())
-        docs = list(nlp.pipe(texts))
+    if text_loc.parts[-1].endswith('.conllu'):
+        docs = []
+        with text_loc.open() as file_:
+            for conllu_doc in read_conllu(file_):
+                for conllu_sent in conllu_doc:
+                    words = [line[1] for line in conllu_sent]
+                    docs.append(Doc(nlp.vocab, words=words))
+        for name, component in nlp.pipeline:
+            docs = list(component.pipe(docs))
+    else:
+        with text_loc.open('r', encoding='utf8') as text_file:
+            texts = split_text(text_file.read())
+            docs = list(nlp.pipe(texts))
     with sys_loc.open('w', encoding='utf8') as out_file:
         write_conllu(docs, out_file)
     with gold_loc.open('r', encoding='utf8') as gold_file:
@@ -261,12 +271,12 @@ def load_nlp(corpus, config, vectors=None):
 
 
 def initialize_pipeline(nlp, docs, golds, config, device):
+    nlp.add_pipe(nlp.create_pipe('tagger'))
     nlp.add_pipe(nlp.create_pipe('parser'))
     if config.multitask_tag:
         nlp.parser.add_multitask_objective('tag')
     if config.multitask_sent:
         nlp.parser.add_multitask_objective('sent_start')
-    nlp.add_pipe(nlp.create_pipe('tagger'))
     for gold in golds:
         for tag in gold.tags:
             if tag is not None:
@@ -328,10 +338,12 @@ class TreebankPaths(object):
     config=("Path to json formatted config file", "positional"),
     limit=("Size limit", "option", "n", int),
     use_gpu=("Use GPU", "option", "g", int),
+    use_oracle_segments=("Use oracle segments", "flag", "G", int),
     vectors_dir=("Path to directory with pre-trained vectors, named e.g. en/",
                  "option", "v", Path),
 )
-def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=None):
+def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=None,
+         use_oracle_segments=False):
     spacy.util.fix_random_seed()
     lang.zh.Chinese.Defaults.use_jieba = False
     lang.ja.Japanese.Defaults.use_janome = False
@@ -344,13 +356,17 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=No
     nlp = load_nlp(paths.lang, config, vectors=vectors_dir)
 
     docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
-                            max_doc_length=config.max_doc_length, limit=limit)
+                            max_doc_length=None, limit=limit)
 
     optimizer = initialize_pipeline(nlp, docs, golds, config, use_gpu)
 
     batch_sizes = compounding(config.batch_size//10, config.batch_size, 1.001)
+    nlp.parser.cfg['beam_update_prob'] = 1.0
     for i in range(config.nr_epoch):
-        docs = [nlp.make_doc(doc.text) for doc in docs]
+        docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
+                                max_doc_length=config.max_doc_length, limit=limit,
+                                oracle_segments=use_oracle_segments,
+                                raw_text=not use_oracle_segments)
         Xs = list(zip(docs, golds))
         random.shuffle(Xs)
         batches = minibatch_by_words(Xs, size=batch_sizes)
@@ -365,7 +381,12 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=No
 
         out_path = parses_dir / corpus / 'epoch-{i}.conllu'.format(i=i)
         with nlp.use_params(optimizer.averages):
-            parsed_docs, scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
+            if use_oracle_segments:
+                parsed_docs, scores = evaluate(nlp, paths.dev.conllu,
+                                               paths.dev.conllu, out_path)
+            else:
+                parsed_docs, scores = evaluate(nlp, paths.dev.text,
+                                               paths.dev.conllu, out_path)
         print_progress(i, losses, scores)
         _render_parses(i, parsed_docs[:50])
 