From 9e960d24fcf6aeb0439abd669093398e102ce82c Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 25 Feb 2018 14:54:47 +0100 Subject: [PATCH] Refactor conllu script, fix interface, generalize --- examples/training/conllu.py | 113 ++++++++++++++++++------------------ 1 file changed, 57 insertions(+), 56 deletions(-) diff --git a/examples/training/conllu.py b/examples/training/conllu.py index 7f8c817d2..f7c9b5fef 100644 --- a/examples/training/conllu.py +++ b/examples/training/conllu.py @@ -13,7 +13,7 @@ import json import spacy import spacy.util from spacy.tokens import Token, Doc -from spacy.gold import GoldParse, minibatch +from spacy.gold import GoldParse from spacy.syntax.nonproj import projectivize from collections import defaultdict, Counter from timeit import default_timer as timer @@ -24,7 +24,7 @@ import random import numpy.random import cytoolz -from spacy._align import align +import conll17_ud_eval random.seed(0) numpy.random.seed(0) @@ -43,7 +43,8 @@ def minibatch_by_words(items, size=5000): try: doc, gold = next(items) except StopIteration: - yield batch + if batch: + yield batch return batch_size -= len(doc) batch.append((doc, gold)) @@ -56,9 +57,9 @@ def minibatch_by_words(items, size=5000): # Data reading # ################ +space_re = re.compile('\s+') def split_text(text): - return [par.strip().replace('\n', ' ') - for par in text.split('\n\n')] + return [space_re.sub(' ', par.strip()) for par in text.split('\n\n')] def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False, @@ -132,7 +133,10 @@ def read_conllu(file_): doc.append(sent) sent = [] else: - sent.append(line.strip().split()) + sent.append(list(line.strip().split('\t'))) + if len(sent[-1]) != 10: + print(repr(line)) + raise ValueError if sent: doc.append(sent) if doc: @@ -176,50 +180,21 @@ def golds_to_gold_tuples(docs, golds): # Evaluation # ############## -def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False, - joint_sbd=True, limit=None): - with open(text_loc) as text_file: - with open(conllu_loc) as conllu_file: - docs, golds = read_data(nlp, conllu_file, text_file, - oracle_segments=oracle_segments, limit=limit) - if joint_sbd: - pass - else: - sbd = nlp.create_pipe('sentencizer') - for doc in docs: - doc = sbd(doc) - for sent in doc.sents: - sent[0].is_sent_start = True - for word in sent[1:]: - word.is_sent_start = False - scorer = nlp.evaluate(zip(docs, golds)) - return docs, scorer +def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None): + with text_loc.open('r', encoding='utf8') as text_file: + texts = split_text(text_file.read()) + docs = list(nlp.pipe(texts)) + with sys_loc.open('w', encoding='utf8') as out_file: + write_conllu(docs, out_file) + with gold_loc.open('r', encoding='utf8') as gold_file: + gold_ud = conll17_ud_eval.load_conllu(gold_file) + with sys_loc.open('r', encoding='utf8') as sys_file: + sys_ud = conll17_ud_eval.load_conllu(sys_file) + scores = conll17_ud_eval.evaluate(gold_ud, sys_ud) + return scores -def print_progress(itn, losses, scorer): - scores = {} - for col in ['dep_loss', 'tag_loss', 'uas', 'tags_acc', 'token_acc', - 'ents_p', 'ents_r', 'ents_f', 'cpu_wps', 'gpu_wps']: - scores[col] = 0.0 - scores['dep_loss'] = losses.get('parser', 0.0) - scores['ner_loss'] = losses.get('ner', 0.0) - scores['tag_loss'] = losses.get('tagger', 0.0) - scores.update(scorer.scores) - tpl = '\t'.join(( - '{:d}', - '{dep_loss:.3f}', - '{ner_loss:.3f}', - '{uas:.3f}', - '{ents_p:.3f}', - '{ents_r:.3f}', - '{ents_f:.3f}', - '{tags_acc:.3f}', - '{token_acc:.3f}', - )) - print(tpl.format(itn, **scores)) - - -def print_conllu(docs, file_): +def write_conllu(docs, file_): merger = Matcher(docs[0].vocab) merger.add('SUBTOK', None, [{'DEP': 'subtok', 'op': '+'}]) for i, doc in enumerate(docs): @@ -236,6 +211,31 @@ def print_conllu(docs, file_): file_.write(token._.get_conllu_lines(k) + '\n') file_.write('\n') + +def print_progress(itn, losses, ud_scores): + fields = { + 'dep_loss': losses.get('parser', 0.0), + 'tag_loss': losses.get('tagger', 0.0), + 'words': ud_scores['Words'].f1 * 100, + 'sents': ud_scores['Sentences'].f1 * 100, + 'tags': ud_scores['XPOS'].f1 * 100, + 'uas': ud_scores['UAS'].f1 * 100, + 'las': ud_scores['LAS'].f1 * 100, + } + header = ['Epoch', 'Loss', 'LAS', 'UAS', 'TAG', 'SENT', 'WORD'] + if itn == 0: + print('\t'.join(header)) + tpl = '\t'.join(( + '{:d}', + '{dep_loss:.1f}', + '{las:.1f}', + '{uas:.1f}', + '{tags:.1f}', + '{sents:.1f}', + '{words:.1f}', + )) + print(tpl.format(itn, **fields)) + #def get_sent_conllu(sent, sent_id): # lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)] @@ -275,7 +275,6 @@ def load_nlp(corpus, config): return nlp def initialize_pipeline(nlp, docs, golds, config): - print("Create parser") nlp.add_pipe(nlp.create_pipe('parser')) if config.multitask_tag: nlp.parser.add_multitask_objective('tag') @@ -347,14 +346,16 @@ class TreebankPaths(object): @plac.annotations( ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path), - config=("Path to json formatted config file", "positional", None, Config.load), corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc", "positional", None, str), - parses_loc=("Path to write the development parses", "positional", None, Path), + parses_dir=("Directory to write the development parses", "positional", None, Path), + config=("Path to json formatted config file", "positional", None, Config.load), limit=("Size limit", "option", "n", int) ) -def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10): +def main(ud_dir, parses_dir, config, corpus, limit=0): paths = TreebankPaths(ud_dir, corpus) + if not (parses_dir / corpus).exists(): + (parses_dir / corpus).mkdir() print("Train and evaluate", corpus, "using lang", paths.lang) nlp = load_nlp(paths.lang, config) @@ -362,6 +363,7 @@ def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10): max_doc_length=config.max_doc_length, limit=limit) optimizer = initialize_pipeline(nlp, docs, golds, config) + for i in range(config.nr_epoch): docs = [nlp.make_doc(doc.text) for doc in docs] batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size) @@ -374,11 +376,10 @@ def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10): nlp.update(batch_docs, batch_gold, sgd=optimizer, drop=config.dropout, losses=losses) + out_path = parses_dir / corpus / 'epoch-{i}.conllu'.format(i=i) with nlp.use_params(optimizer.averages): - dev_docs, scorer = parse_dev_data(nlp, paths.dev.text, paths.dev.conllu) - print_progress(i, losses, scorer) - with open(parses_loc, 'w') as file_: - print_conllu(dev_docs, file_) + scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path) + print_progress(i, losses, scores) if __name__ == '__main__':