Refactor conllu script, fix interface, generalize

This commit is contained in:
Matthew Honnibal 2018-02-25 14:54:47 +01:00
parent 551c93fe01
commit 9e960d24fc

View File

@ -13,7 +13,7 @@ import json
import spacy
import spacy.util
from spacy.tokens import Token, Doc
from spacy.gold import GoldParse, minibatch
from spacy.gold import GoldParse
from spacy.syntax.nonproj import projectivize
from collections import defaultdict, Counter
from timeit import default_timer as timer
@ -24,7 +24,7 @@ import random
import numpy.random
import cytoolz
from spacy._align import align
import conll17_ud_eval
random.seed(0)
numpy.random.seed(0)
@ -43,7 +43,8 @@ def minibatch_by_words(items, size=5000):
try:
doc, gold = next(items)
except StopIteration:
yield batch
if batch:
yield batch
return
batch_size -= len(doc)
batch.append((doc, gold))
@ -56,9 +57,9 @@ def minibatch_by_words(items, size=5000):
# Data reading #
################
space_re = re.compile('\s+')
def split_text(text):
return [par.strip().replace('\n', ' ')
for par in text.split('\n\n')]
return [space_re.sub(' ', par.strip()) for par in text.split('\n\n')]
def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
@ -132,7 +133,10 @@ def read_conllu(file_):
doc.append(sent)
sent = []
else:
sent.append(line.strip().split())
sent.append(list(line.strip().split('\t')))
if len(sent[-1]) != 10:
print(repr(line))
raise ValueError
if sent:
doc.append(sent)
if doc:
@ -176,50 +180,21 @@ def golds_to_gold_tuples(docs, golds):
# Evaluation #
##############
def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
joint_sbd=True, limit=None):
with open(text_loc) as text_file:
with open(conllu_loc) as conllu_file:
docs, golds = read_data(nlp, conllu_file, text_file,
oracle_segments=oracle_segments, limit=limit)
if joint_sbd:
pass
else:
sbd = nlp.create_pipe('sentencizer')
for doc in docs:
doc = sbd(doc)
for sent in doc.sents:
sent[0].is_sent_start = True
for word in sent[1:]:
word.is_sent_start = False
scorer = nlp.evaluate(zip(docs, golds))
return docs, scorer
def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
with text_loc.open('r', encoding='utf8') as text_file:
texts = split_text(text_file.read())
docs = list(nlp.pipe(texts))
with sys_loc.open('w', encoding='utf8') as out_file:
write_conllu(docs, out_file)
with gold_loc.open('r', encoding='utf8') as gold_file:
gold_ud = conll17_ud_eval.load_conllu(gold_file)
with sys_loc.open('r', encoding='utf8') as sys_file:
sys_ud = conll17_ud_eval.load_conllu(sys_file)
scores = conll17_ud_eval.evaluate(gold_ud, sys_ud)
return scores
def print_progress(itn, losses, scorer):
scores = {}
for col in ['dep_loss', 'tag_loss', 'uas', 'tags_acc', 'token_acc',
'ents_p', 'ents_r', 'ents_f', 'cpu_wps', 'gpu_wps']:
scores[col] = 0.0
scores['dep_loss'] = losses.get('parser', 0.0)
scores['ner_loss'] = losses.get('ner', 0.0)
scores['tag_loss'] = losses.get('tagger', 0.0)
scores.update(scorer.scores)
tpl = '\t'.join((
'{:d}',
'{dep_loss:.3f}',
'{ner_loss:.3f}',
'{uas:.3f}',
'{ents_p:.3f}',
'{ents_r:.3f}',
'{ents_f:.3f}',
'{tags_acc:.3f}',
'{token_acc:.3f}',
))
print(tpl.format(itn, **scores))
def print_conllu(docs, file_):
def write_conllu(docs, file_):
merger = Matcher(docs[0].vocab)
merger.add('SUBTOK', None, [{'DEP': 'subtok', 'op': '+'}])
for i, doc in enumerate(docs):
@ -236,6 +211,31 @@ def print_conllu(docs, file_):
file_.write(token._.get_conllu_lines(k) + '\n')
file_.write('\n')
def print_progress(itn, losses, ud_scores):
fields = {
'dep_loss': losses.get('parser', 0.0),
'tag_loss': losses.get('tagger', 0.0),
'words': ud_scores['Words'].f1 * 100,
'sents': ud_scores['Sentences'].f1 * 100,
'tags': ud_scores['XPOS'].f1 * 100,
'uas': ud_scores['UAS'].f1 * 100,
'las': ud_scores['LAS'].f1 * 100,
}
header = ['Epoch', 'Loss', 'LAS', 'UAS', 'TAG', 'SENT', 'WORD']
if itn == 0:
print('\t'.join(header))
tpl = '\t'.join((
'{:d}',
'{dep_loss:.1f}',
'{las:.1f}',
'{uas:.1f}',
'{tags:.1f}',
'{sents:.1f}',
'{words:.1f}',
))
print(tpl.format(itn, **fields))
#def get_sent_conllu(sent, sent_id):
# lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)]
@ -275,7 +275,6 @@ def load_nlp(corpus, config):
return nlp
def initialize_pipeline(nlp, docs, golds, config):
print("Create parser")
nlp.add_pipe(nlp.create_pipe('parser'))
if config.multitask_tag:
nlp.parser.add_multitask_objective('tag')
@ -347,14 +346,16 @@ class TreebankPaths(object):
@plac.annotations(
ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
config=("Path to json formatted config file", "positional", None, Config.load),
corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc",
"positional", None, str),
parses_loc=("Path to write the development parses", "positional", None, Path),
parses_dir=("Directory to write the development parses", "positional", None, Path),
config=("Path to json formatted config file", "positional", None, Config.load),
limit=("Size limit", "option", "n", int)
)
def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10):
def main(ud_dir, parses_dir, config, corpus, limit=0):
paths = TreebankPaths(ud_dir, corpus)
if not (parses_dir / corpus).exists():
(parses_dir / corpus).mkdir()
print("Train and evaluate", corpus, "using lang", paths.lang)
nlp = load_nlp(paths.lang, config)
@ -362,6 +363,7 @@ def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10):
max_doc_length=config.max_doc_length, limit=limit)
optimizer = initialize_pipeline(nlp, docs, golds, config)
for i in range(config.nr_epoch):
docs = [nlp.make_doc(doc.text) for doc in docs]
batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
@ -374,11 +376,10 @@ def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10):
nlp.update(batch_docs, batch_gold, sgd=optimizer,
drop=config.dropout, losses=losses)
out_path = parses_dir / corpus / 'epoch-{i}.conllu'.format(i=i)
with nlp.use_params(optimizer.averages):
dev_docs, scorer = parse_dev_data(nlp, paths.dev.text, paths.dev.conllu)
print_progress(i, losses, scorer)
with open(parses_loc, 'w') as file_:
print_conllu(dev_docs, file_)
scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
print_progress(i, losses, scores)
if __name__ == '__main__':