2018-04-29 16:50:25 +03:00
|
|
|
'''Evaluate a trained model on CONLL 2017 UD treebank data. Takes .conllu files
and raw text, writes .conllu format for the test and development data, allowing
the official scorer to be used.
'''
|
|
|
|
from __future__ import unicode_literals
|
|
|
|
import plac
|
|
|
|
import tqdm
|
|
|
|
from pathlib import Path
|
|
|
|
import re
|
|
|
|
import sys
|
|
|
|
import json
|
|
|
|
|
|
|
|
import spacy
|
|
|
|
import spacy.util
|
|
|
|
from ..tokens import Token, Doc
|
|
|
|
from ..gold import GoldParse
|
|
|
|
from ..util import compounding, minibatch_by_words
|
|
|
|
from ..syntax.nonproj import projectivize
|
|
|
|
from ..matcher import Matcher
|
2018-05-08 20:40:33 +03:00
|
|
|
#from ..morphology import Fused_begin, Fused_inside
|
2018-04-29 16:50:25 +03:00
|
|
|
from .. import displacy
|
|
|
|
from collections import defaultdict, Counter
|
|
|
|
from timeit import default_timer as timer
|
2018-05-09 01:28:03 +03:00
|
|
|
# Placeholders for the fused-token morphology features whose import from
# ..morphology is commented out above. _get_token_conllu passes these to
# Token.check_morph.
# NOTE(review): while they are None, check_morph(None) may not detect fused
# tokens as intended -- confirm against the spaCy version in use.
Fused_begin = Fused_inside = None
|
2018-04-29 16:50:25 +03:00
|
|
|
|
|
|
|
import itertools
|
|
|
|
import random
|
|
|
|
import numpy.random
|
|
|
|
import cytoolz
|
|
|
|
|
|
|
|
from . import conll17_ud_eval
|
|
|
|
|
|
|
|
from .. import lang
|
|
|
|
from .. import lang
|
|
|
|
from ..lang import zh
|
|
|
|
from ..lang import ja
|
|
|
|
from ..lang import ru
|
|
|
|
|
|
|
|
|
|
|
|
################
|
|
|
|
# Data reading #
|
|
|
|
################
|
|
|
|
|
|
|
|
# Any run of whitespace (spaces, tabs, newlines). A raw string is used so
# that "\s" is not treated as an (invalid) string escape sequence.
space_re = re.compile(r'\s+')


def split_text(text):
    '''Split raw text into paragraphs at double newlines.

    Each paragraph is stripped and every internal run of whitespace is
    collapsed to a single space, so line breaks inside a paragraph are lost.

    Returns:
        A list of paragraph strings (an empty input yields ``['']``).
    '''
    return [space_re.sub(' ', par.strip()) for par in text.split('\n\n')]
|
|
|
|
|
|
|
|
|
|
|
|
##############
|
|
|
|
# Evaluation #
|
|
|
|
##############
|
|
|
|
|
|
|
|
def read_conllu(file_):
    '''Read CoNLL-U lines into nested lists: documents -> sentences -> rows.

    ``file_`` may be any iterable of lines, such as an open text file.

    Line handling:
        - A ``# newdoc`` comment starts a new document.
        - Other ``#`` comment lines are skipped.
        - A blank line ends the current sentence.
        - Every other line is stripped and split on tabs into its fields.

    Returns:
        A list of documents. Each document is a list of sentences, each
        sentence is a list of rows, and each row is a list of 10 field
        strings.

    Raises:
        ValueError: if a non-comment, non-blank line does not have exactly 10
            tab-separated fields.
    '''
    docs = []
    doc = []
    sent = []
    for line in file_:
        if line.startswith('# newdoc'):
            # Close a sentence that was not terminated by a blank line, so it
            # stays in the document it was read in.
            if sent:
                doc.append(sent)
                sent = []
            if doc:
                docs.append(doc)
            doc = []
        elif line.startswith('#'):
            continue
        elif not line.strip():
            if sent:
                doc.append(sent)
            sent = []
        else:
            fields = line.strip().split('\t')
            if len(fields) != 10:
                raise ValueError(
                    'Expected 10 tab-separated CoNLL-U fields, got %d: %r'
                    % (len(fields), line))
            sent.append(fields)
    if sent:
        doc.append(sent)
    if doc:
        docs.append(doc)
    return docs
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
    '''Parse an input file, write the parses as CoNLL-U and score them.

    Input handling depends on the file type of ``text_loc``:
        - ``.conllu``: column 2 of each word row is used as pre-tokenized
          words. Each sentence becomes its own ``Doc``, and every pipeline
          component is run over the docs, bypassing the tokenizer.
        - Anything else: the file is read as raw text, split into paragraphs
          with ``split_text`` and parsed with ``nlp.pipe``.

    The parses are written to ``sys_loc``, then both the gold file and the
    written file are loaded and compared with ``conll17_ud_eval``.

    Args:
        nlp: the loaded spaCy pipeline.
        text_loc: ``pathlib.Path`` of the input (.conllu or raw text).
        gold_loc: ``pathlib.Path`` of the gold-standard .conllu file.
        sys_loc: ``pathlib.Path`` the system output is written to.
        limit: accepted for interface compatibility; currently unused.

    Returns:
        A ``(docs, scores)`` tuple: the parsed docs and the result of
        ``conll17_ud_eval.evaluate``.
    '''
    if text_loc.parts[-1].endswith('.conllu'):
        docs = []
        # Explicit encoding, consistent with the other files opened here.
        with text_loc.open('r', encoding='utf8') as file_:
            for conllu_doc in read_conllu(file_):
                for conllu_sent in conllu_doc:
                    # Skip multiword-token range rows (ID like "1-2") and empty
                    # nodes (ID like "8.1"). Their forms are not separate
                    # surface words: keeping them would duplicate text and
                    # break alignment with the gold file.
                    words = [line[1] for line in conllu_sent
                             if '-' not in line[0] and '.' not in line[0]]
                    docs.append(Doc(nlp.vocab, words=words))
        # The docs are already tokenized, so run each component directly.
        for name, component in nlp.pipeline:
            docs = list(component.pipe(docs))
    else:
        with text_loc.open('r', encoding='utf8') as text_file:
            texts = split_text(text_file.read())
        docs = list(nlp.pipe(texts))
    with sys_loc.open('w', encoding='utf8') as out_file:
        write_conllu(docs, out_file)
    with gold_loc.open('r', encoding='utf8') as gold_file:
        gold_ud = conll17_ud_eval.load_conllu(gold_file)
    with sys_loc.open('r', encoding='utf8') as sys_file:
        sys_ud = conll17_ud_eval.load_conllu(sys_file)
    scores = conll17_ud_eval.evaluate(gold_ud, sys_ud)
    return docs, scores
|
|
|
|
|
|
|
|
|
|
|
|
def write_conllu(docs, file_):
    '''Write parsed docs to ``file_`` in CoNLL-U format.

    Each doc is first modified in place:
        - Every run of tokens labelled ``subtok`` is merged together with the
          token that follows it.
        - Any token that is its own head is relabelled ``ROOT``.

    Output layout:
        - Each doc becomes one ``# newdoc id = <i>`` block.
        - Each sentence gets ``# sent_id`` and ``# text`` comments, one line
          per token (from ``_get_token_conllu``) and a terminating blank line.

    An empty ``docs`` writes nothing.

    Raises:
        ValueError: if a sentence contains no token that is its own head and
            labelled ``ROOT``. The sentence has already been written when
            this is raised.
    '''
    if not docs:
        # Nothing to write; also avoids docs[0] raising IndexError below.
        return
    merger = Matcher(docs[0].vocab)
    merger.add('SUBTOK', None, [{'DEP': 'subtok', 'op': '+'}])
    for i, doc in enumerate(docs):
        matches = merger(doc)
        # Extend each match by one token so the subtok run is merged together
        # with the token right after it.
        spans = [doc[start:end+1] for _, start, end in matches]
        # Character offsets are collected before any merge, presumably because
        # token indices change once tokens are merged.
        offsets = [(span.start_char, span.end_char) for span in spans]
        for start_char, end_char in offsets:
            doc.merge(start_char, end_char)
        # TODO: This shouldn't be necessary? Should be handled in merge
        for word in doc:
            if word.i == word.head.i:
                word.dep_ = 'ROOT'
        file_.write("# newdoc id = {i}\n".format(i=i))
        for j, sent in enumerate(doc.sents):
            file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
            file_.write("# text = {text}\n".format(text=sent.text))
            for k, token in enumerate(sent):
                file_.write(_get_token_conllu(token, k, len(sent)) + '\n')
            file_.write('\n')
            has_root = any(word.head.i == word.i and word.dep_ == 'ROOT'
                           for word in sent)
            if not has_root:
                details = '\n'.join(
                    '%d %s %s %d %s' % (w.i, w.text, w.head.text, w.head.i, w.dep_)
                    for w in sent)
                raise ValueError(
                    'Rootless sentence in doc %d: %r\n%s' % (i, sent.text, details))
|
|
|
|
|
|
|
|
|
|
|
|
def _get_token_conllu(token, k, sent_len):
    '''Format one token as CoNLL-U text.

    Args:
        token: the token to format.
        k: the token's 0-based position within its sentence.
        sent_len: number of tokens in the sentence.

    Returns:
        The token's CoNLL-U line. HEAD is 0 when the token is its own head;
        otherwise it is the head's 1-based position in the sentence.

        If the token begins a fused token (``check_morph(Fused_begin)``, not
        last in the sentence), a range line ``"<start>-<end>"`` comes first.
        That line holds the concatenated texts and ``_`` in the other columns,
        and is joined to the token line by a newline.

    NOTE(review): relies on the module-level ``Fused_begin``/``Fused_inside``
    values and on the ``split_start``/``split_end`` Token extensions that
    ``main`` registers.
    '''
    begins_fused = token.check_morph(Fused_begin) and (k+1 < sent_len)
    if begins_fused:
        n = 1
        text = [token.text]
        # Collect following tokens marked as inside the fused token. Stop at
        # the end of the sentence instead of running past the doc (nbor would
        # raise IndexError).
        while k + n < sent_len and token.nbor(n).check_morph(Fused_inside):
            text.append(token.nbor(n).text)
            n += 1
        id_ = '%d-%d' % (k+1, (k+n))
        fields = [id_, ''.join(text)] + ['_'] * 8
        lines = ['\t'.join(fields)]
    else:
        lines = []
    if token.head.i == token.i:
        head = 0
    else:
        head = k + (token.head.i - token.i) + 1
    fields = [str(k+1), token.text, token.lemma_, token.pos_, token.tag_, '_',
              str(head), token.dep_.lower(), '_', '_']
    if begins_fused:
        norm = token.norm_
        if k == 0:
            # Sentence-initial: capitalise the first character. Slicing keeps
            # an empty norm from raising IndexError.
            fields[1] = norm[:1].upper() + norm[1:]
        else:
            fields[1] = norm
    elif token.check_morph(Fused_inside):
        fields[1] = token.norm_
    elif token._.split_start is not None:
        # The token is part of a split run: divide the run's first token's
        # text across the run and take this token's piece.
        split_start = token._.split_start
        split_end = token._.split_end
        split_len = (split_end.i - split_start.i) + 1
        n_in_split = token.i - split_start.i
        subtokens = guess_fused_orths(split_start.text, [''] * split_len)
        fields[1] = subtokens[n_in_split]
    lines.append('\t'.join(fields))
    return '\n'.join(lines)
|
|
|
|
|
|
|
|
|
|
|
|
def guess_fused_orths(word, ud_forms):
    '''Divide ``word`` into orths, one per form in ``ud_forms``.

    The UD data 'fused tokens' don't necessarily expand to forms that match
    the original string, but we need orths that exactly match it. Here we
    make a best effort to divide up the word.

    Args:
        word: the surface string to divide.
        ud_forms: sequence of strings the word expands to.

    Returns:
        - ``ud_forms`` itself, if its forms already concatenate to ``word``.
        - Otherwise, if the total lengths agree, ``word`` cut at the forms'
          lengths.
        - Otherwise, every piece after the first is one character and the
          first piece takes the rest. For example, a 6-character word with
          three forms splits as [4, 1, 1]. A non-empty ``word`` with no forms
          comes back as ``[word]``.

    Raises:
        ValueError: if the lengths agree but some form is empty, or if
            ``word`` is too short to give each form after the first one
            character.
    '''
    if word == ''.join(ud_forms):
        # Happy case: we get a perfect split, with each letter accounted for.
        return ud_forms
    if len(word) == sum(len(form) for form in ud_forms):
        # Unideal, but at least lengths match: cut word at the forms' lengths.
        if not all(ud_forms):
            raise ValueError(
                'Cannot split %r using an empty form: %r' % (word, ud_forms))
        output = []
        start = 0
        for form in ud_forms:
            output.append(word[start:start + len(form)])
            start += len(form)
        return output
    # The orths *must* equal the original string: arbitrarily give each form
    # after the first a single character and the first form the remainder.
    n_tail = len(ud_forms) - 1
    if n_tail > len(word):
        raise ValueError(
            'Word %r is too short to split into %d pieces'
            % (word, len(ud_forms)))
    first = word[:len(word) - n_tail]
    return [first] + list(word[len(first):])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_results(name, ud_scores):
    '''Print one tab-separated row of F1 scores and return them as a dict.

    The row is ``name`` followed by LAS, UAS, tags, sentences and words, each
    formatted with one decimal place.

    Args:
        name: label for the first column; printed literally.
        ud_scores: mapping from metric name ('Words', 'Sentences', 'XPOS',
            'UAS', 'LAS') to objects with an ``f1`` attribute, or ``None`` to
            report all-zero scores.

    Returns:
        Dict with keys ``words``, ``sents``, ``tags``, ``uas``, ``las`` (in
        that order) holding the F1 scores multiplied by 100.
    '''
    metrics = (('words', 'Words'), ('sents', 'Sentences'), ('tags', 'XPOS'),
               ('uas', 'UAS'), ('las', 'LAS'))
    if ud_scores is None:
        fields = {key: 0.0 for key, _ in metrics}
    else:
        fields = {key: ud_scores[ud_key].f1 * 100 for key, ud_key in metrics}
    # Format the numbers on their own and join them with the name, rather
    # than splicing ``name`` into a format template: a name containing '{' or
    # '}' would otherwise be parsed as a replacement field.
    columns = ['{:.1f}'.format(fields[key])
               for key in ('las', 'uas', 'tags', 'sents', 'words')]
    print('\t'.join([name] + columns))
    return fields
|
|
|
|
|
|
|
|
|
|
|
|
def get_token_split_start(token):
    '''Getter for the ``split_start`` Token extension.

    Tokens with empty text that follow a token with text form a run; the
    token with text is the run's start.

    Returns:
        - For a token with empty text: the nearest preceding token with
          text. The token must not be at index 0 (asserted).
        - For a token with text that is directly followed by an empty-text
          token: the token itself.
        - Otherwise: ``None``.
    '''
    if token.text != '':
        has_next = (token.i + 1) < len(token.doc)
        if has_next and token.nbor(1).text == '':
            return token
        return None
    assert token.i != 0
    offset = -1
    while token.nbor(offset).text == '':
        offset -= 1
    return token.nbor(offset)
|
|
|
|
|
|
|
|
|
|
|
|
def get_token_split_end(token):
    '''Getter for the ``split_end`` Token extension.

    Returns:
        - For the last token in the doc: the token itself if its text is
          empty, else ``None``.
        - For a token with text whose next token also has text: ``None``.
        - Otherwise: the last token of the consecutive empty-text run that
          starts at ``token`` (if it is empty) or right after it.
    '''
    doc_len = len(token.doc)
    if token.i + 1 == doc_len:
        return None if token.text != '' else token
    if token.text != '' and token.nbor(1).text != '':
        return None
    step = 1
    while token.i + step < doc_len and token.nbor(step).text == '':
        step += 1
    return token.nbor(step - 1)
|
|
|
|
|
|
|
|
|
|
|
|
##################
|
|
|
|
# Initialization #
|
|
|
|
##################
|
|
|
|
|
|
|
|
|
|
|
|
def load_nlp(experiments_dir, corpus):
    '''Load the model saved at ``experiments_dir / corpus / 'best-model'``.'''
    model_path = experiments_dir / corpus / 'best-model'
    return spacy.load(model_path)
|
|
|
|
|
|
|
|
def initialize_pipeline(nlp, docs, golds, config, device):
    '''Append a newly created 'parser' pipe to ``nlp`` and return ``nlp``.

    ``docs``, ``golds``, ``config`` and ``device`` are accepted but unused.
    '''
    parser = nlp.create_pipe('parser')
    nlp.add_pipe(parser)
    return nlp
|
|
|
|
|
|
|
|
|
|
|
|
@plac.annotations(
    test_data_dir=("Path to Universal Dependencies test data", "positional", None, Path),
    experiment_dir=("Parent directory with output model", "positional", None, Path),
    corpus=("UD corpus to evaluate, e.g. UD_English, UD_Spanish, etc", "positional", None, str),
)
def main(test_data_dir, experiment_dir, corpus):
    '''Evaluate ``experiment_dir/corpus/best-model`` on the test and dev sets.

    The treebank code is taken from the model's ``meta['treebank']``. For
    each section, the model is run twice:
        - on the UDPipe-tokenized ``<treebank>-udpipe.conllu`` input;
        - on the raw ``<treebank>.txt`` input.

    Each run is scored against the gold ``<treebank>.conllu``. The scores are
    printed as a table and saved as JSON under ``experiment_dir / corpus``.

    Files written per section:
        - raw input: ``{section}.conllu`` and ``{section}-accuracy.json``;
        - UDPipe input: ``{section}-udp.conllu`` and
          ``{section}-udp-accuracy.json``.
    '''
    # Getters that locate runs of empty-text tokens; used when formatting
    # CoNLL-U output.
    Token.set_extension('split_start', getter=get_token_split_start)
    Token.set_extension('split_end', getter=get_token_split_end)
    Token.set_extension('begins_fused', default=False)
    Token.set_extension('inside_fused', default=False)
    # Turn off the optional jieba / janome / pymorphy2 integrations.
    lang.zh.Chinese.Defaults.use_jieba = False
    lang.ja.Japanese.Defaults.use_janome = False
    lang.ru.Russian.Defaults.use_pymorphy2 = False

    nlp = load_nlp(experiment_dir, corpus)

    treebank_code = nlp.meta['treebank']
    for section in ('test', 'dev'):
        if section == 'dev':
            section_dir = 'conll17-ud-development-2017-03-19'
        else:
            section_dir = 'conll17-ud-test-2017-05-09'
        input_dir = test_data_dir / 'input' / section_dir
        text_path = input_dir / (treebank_code + '.txt')
        udpipe_path = input_dir / (treebank_code + '-udpipe.conllu')
        gold_path = test_data_dir / 'gold' / section_dir / (treebank_code + '.conllu')

        header = [section, 'LAS', 'UAS', 'TAG', 'SENT', 'WORD']
        print('\t'.join(header))
        inputs = {'gold': gold_path, 'udp': udpipe_path, 'raw': text_path}
        for input_type in ('udp', 'raw'):
            input_path = inputs[input_type]
            # Each input type needs its own output files; with a shared name,
            # the 'raw' run overwrote the 'udp' results. 'raw' keeps the
            # original un-suffixed names so existing consumers still find them.
            suffix = '' if input_type == 'raw' else '-' + input_type
            output_path = experiment_dir / corpus / '{section}{suffix}.conllu'.format(
                section=section, suffix=suffix)

            _, test_scores = evaluate(nlp, input_path, gold_path, output_path)

            accuracy = print_results(input_type, test_scores)
            acc_path = experiment_dir / corpus / '{section}{suffix}-accuracy.json'.format(
                section=section, suffix=suffix)
            with acc_path.open('w', encoding='utf8') as file_:
                file_.write(json.dumps(accuracy, indent=2))
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # plac builds the command-line interface from main's annotations.
    plac.call(main)
|