2015-01-09 20:53:26 +03:00
|
|
|
#!/usr/bin/env python
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import unicode_literals
|
|
|
|
|
|
|
|
import os
|
|
|
|
from os import path
|
|
|
|
import shutil
|
|
|
|
import codecs
|
|
|
|
import random
|
|
|
|
import time
|
|
|
|
import gzip
|
|
|
|
|
|
|
|
import plac
|
|
|
|
import cProfile
|
|
|
|
import pstats
|
|
|
|
|
|
|
|
import spacy.util
|
|
|
|
from spacy.en import English
|
|
|
|
from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
|
|
|
|
|
|
|
|
from spacy.syntax.parser import GreedyParser
|
2015-02-02 15:02:48 +03:00
|
|
|
from spacy.syntax.parser import OracleError
|
2015-01-09 20:53:26 +03:00
|
|
|
from spacy.syntax.util import Config
|
|
|
|
|
|
|
|
|
2015-01-30 08:36:24 +03:00
|
|
|
def is_punct_label(label):
    """Tell whether a dependency label marks a punctuation arc.

    Accepts the single-letter label ``'P'`` (exact match) or any casing of
    ``'punct'``.
    """
    if label == 'P':
        return True
    return label.lower() == 'punct'
|
|
|
|
|
|
|
|
|
2015-01-09 20:53:26 +03:00
|
|
|
def read_tokenized_gold(file_):
    """Read a standard CoNLL/MALT-style format.

    Sentences are separated by blank lines; each remaining line is one token
    and is split into fields by ``_parse_line``. Empty chunks (an empty file,
    or runs of more than one blank line) are skipped, as in
    ``read_docparse_gold``.

    A head of -1 is replaced by the token's own position in the sentence.
    ``_parse_line`` produces -1 for a 4-column line whose head column is 0.

    Args:
        file_: an open text file object; it is read in full.

    Returns:
        list of ``(text, [words], ids, words, tags, heads, labels)`` tuples.
        ``text`` is the words joined by single spaces, and the tokenization is
        the whole sentence as one word list, so the layout matches
        ``read_docparse_gold``.
    """
    sents = []
    for sent_str in file_.read().strip().split('\n\n'):
        # Without this guard an empty file (or '\n\n\n') reaches
        # _parse_line('') and fails with an IndexError.
        if not sent_str.strip():
            continue
        ids = []
        words = []
        heads = []
        labels = []
        tags = []
        for i, line in enumerate(sent_str.strip().split('\n')):
            id_, word, pos_string, head_idx, label = _parse_line(line)
            if head_idx == -1:
                head_idx = i
            ids.append(id_)
            words.append(word)
            heads.append(head_idx)
            labels.append(label)
            tags.append(pos_string)
        text = ' '.join(words)
        sents.append((text, [words], ids, words, tags, heads, labels))
    return sents
|
|
|
|
|
|
|
|
|
|
|
|
def read_docparse_gold(file_):
    """Parse a gold "docparse" file into per-paragraph annotation tuples.

    Paragraphs are separated by blank lines. Within a paragraph:

    * line 1 is the raw text;
    * line 2 is the gold tokenization. Sentences are separated by ``<SENT>``,
      and tokens are separated by spaces or ``<SEP>``;
    * every further line is one token annotation, parsed by ``_parse_line``.

    The label ``'root'`` is normalised to ``'ROOT'``. A negative head (e.g. -1
    from a 4-column line with head 0) is replaced by the token's own id.

    Args:
        file_: an open text file object; it is read in full.

    Returns:
        list of ``(raw_text, tokenized, ids, words, tags, heads, labels)``
        tuples, where ``tokenized`` is a list of word lists (one per sentence).
    """
    paragraphs = []
    for block in file_.read().strip().split('\n\n'):
        if not block.strip():
            continue
        rows = block.strip().split('\n')
        raw_text = rows[0].strip()
        tok_text = rows[1].strip()

        ids, words, tags, heads, labels = [], [], [], [], []
        for row in rows[2:]:
            id_, word, pos_string, head_idx, label = _parse_line(row)
            ids.append(id_)
            words.append(word)
            tags.append(pos_string)
            heads.append(id_ if head_idx < 0 else head_idx)
            labels.append('ROOT' if label == 'root' else label)

        # Separate name so the outer loop variable is not shadowed (Python 2
        # list comprehensions leak their loop variable).
        tokenized = [segment.replace('<SEP>', ' ').split(' ')
                     for segment in tok_text.split('<SENT>')]
        paragraphs.append((raw_text, tokenized, ids, words, tags, heads, labels))
    return paragraphs
|
|
|
|
|
2015-01-09 20:53:26 +03:00
|
|
|
|
2015-01-28 20:21:13 +03:00
|
|
|
def _map_indices_to_tokens(ids, heads):
|
2015-01-30 02:31:03 +03:00
|
|
|
mapped = []
|
|
|
|
for head in heads:
|
|
|
|
if head not in ids:
|
|
|
|
mapped.append(None)
|
|
|
|
else:
|
|
|
|
mapped.append(ids.index(head))
|
|
|
|
return mapped
|
2015-01-28 20:21:13 +03:00
|
|
|
|
|
|
|
|
2015-01-09 20:53:26 +03:00
|
|
|
def _parse_line(line):
|
|
|
|
pieces = line.split()
|
|
|
|
if len(pieces) == 4:
|
2015-01-28 20:21:13 +03:00
|
|
|
return 0, pieces[0], pieces[1], int(pieces[2]) - 1, pieces[3]
|
2015-01-09 20:53:26 +03:00
|
|
|
else:
|
2015-01-28 20:21:13 +03:00
|
|
|
id_ = int(pieces[0])
|
2015-01-09 20:53:26 +03:00
|
|
|
word = pieces[1]
|
|
|
|
pos = pieces[3]
|
2015-01-28 20:21:13 +03:00
|
|
|
head_idx = int(pieces[6])
|
2015-01-09 20:53:26 +03:00
|
|
|
label = pieces[7]
|
2015-01-28 20:21:13 +03:00
|
|
|
return id_, word, pos, head_idx, label
|
2015-01-09 20:53:26 +03:00
|
|
|
|
2015-01-30 02:31:03 +03:00
|
|
|
|
2015-01-30 08:36:24 +03:00
|
|
|
loss = 0
|
2015-01-30 02:31:03 +03:00
|
|
|
def _align_annotations_to_non_gold_tokens(tokens, words, annot):
|
2015-01-30 08:36:24 +03:00
|
|
|
global loss
|
2015-01-30 02:31:03 +03:00
|
|
|
tags = []
|
|
|
|
heads = []
|
|
|
|
labels = []
|
2015-01-30 08:36:24 +03:00
|
|
|
orig_words = list(words)
|
|
|
|
missed = []
|
2015-01-30 02:31:03 +03:00
|
|
|
for token in tokens:
|
|
|
|
while annot and token.idx > annot[0][0]:
|
2015-01-30 08:36:24 +03:00
|
|
|
miss_id, miss_tag, miss_head, miss_label = annot.pop(0)
|
|
|
|
miss_w = words.pop(0)
|
|
|
|
if not is_punct_label(miss_label):
|
|
|
|
missed.append(miss_w)
|
|
|
|
loss += 1
|
2015-01-30 02:31:03 +03:00
|
|
|
if not annot:
|
|
|
|
tags.append(None)
|
|
|
|
heads.append(None)
|
|
|
|
labels.append(None)
|
|
|
|
continue
|
|
|
|
id_, tag, head, label = annot[0]
|
|
|
|
if token.idx == id_:
|
|
|
|
tags.append(tag)
|
|
|
|
heads.append(head)
|
|
|
|
labels.append(label)
|
|
|
|
annot.pop(0)
|
|
|
|
words.pop(0)
|
|
|
|
elif token.idx < id_:
|
|
|
|
tags.append(None)
|
|
|
|
heads.append(None)
|
|
|
|
labels.append(None)
|
|
|
|
else:
|
|
|
|
raise StandardError
|
2015-01-30 08:36:24 +03:00
|
|
|
#if missed:
|
|
|
|
# print orig_words
|
|
|
|
# print missed
|
|
|
|
# for t in tokens:
|
|
|
|
# print t.idx, t.orth_
|
2015-01-30 02:31:03 +03:00
|
|
|
return loss, tags, heads, labels
|
|
|
|
|
|
|
|
|
|
|
|
def iter_data(paragraphs, tokenizer, gold_preproc=False):
    """Yield ``(tokens, tags, heads, labels)`` examples from gold paragraphs.

    Each paragraph is a ``(raw, tokenized, ids, words, tags, heads, labels)``
    tuple, as built by ``read_docparse_gold``.

    * ``gold_preproc=False``: ``raw`` is tokenized with ``tokenizer`` and the
      gold annotations are aligned to the resulting tokens by matching each
      token's ``idx`` against the annotation ids. Tokens without a matching
      annotation get ``None`` values. One item is yielded per paragraph.
    * ``gold_preproc=True``: the gold tokenization is used. One item is
      yielded per sentence in ``tokenized`` (tokens built with
      ``tokenizer.tokens_from_list``), taking consecutive slices of the
      paragraph-level annotation lists.

    In both modes heads are converted from annotation ids to positions within
    the yielded tokens (``None`` when the head id is not present).
    """
    for raw, tokenized, ids, words, tags, heads, labels in paragraphs:
        if not gold_preproc:
            tokens = tokenizer(raw)
            # Materialise the zip: the aligner needs a real list, and a
            # Python 3 zip is a one-shot iterator.
            _, tok_tags, tok_heads, tok_labels = \
                _align_annotations_to_non_gold_tokens(
                    tokens, list(words), list(zip(ids, tags, heads, labels)))
            token_ids = [t.idx for t in tokens]
            yield (tokens, tok_tags,
                   _map_indices_to_tokens(token_ids, tok_heads), tok_labels)
        else:
            assert len(words) == len(heads)
            # Walk the paragraph-level lists with a running offset instead of
            # re-slicing (copying) the remainder after every sentence.
            start = 0
            for sent_words in tokenized:
                end = start + len(sent_words)
                sent_heads = _map_indices_to_tokens(ids[start:end],
                                                    heads[start:end])
                tokens = tokenizer.tokens_from_list(sent_words)
                yield tokens, tags[start:end], sent_heads, labels[start:end]
                start = end
|
|
|
|
|
|
|
|
|
2015-01-09 20:53:26 +03:00
|
|
|
def get_labels(sents):
    """Collect the dependency labels used on leftward and rightward arcs.

    For every token position ``child`` in each sentence:

    * a head index greater than ``child`` contributes the label to the left set;
    * a head index smaller than ``child`` contributes it to the right set;
    * a self-attached token (head == child) contributes nothing.

    Args:
        sents: iterable of 7-tuples
            ``(raw, tokenized, ids, words, tags, heads, labels)``.

    Returns:
        ``(left_labels, right_labels)``: two sorted, de-duplicated lists.
    """
    left_set = set()
    right_set = set()
    for sent in sents:
        _raw, _tokenized, _ids, _words, _tags, sent_heads, sent_labels = sent
        for position, arc in enumerate(zip(sent_heads, sent_labels)):
            head, label = arc
            if head > position:
                left_set.add(label)
            elif head < position:
                right_set.add(label)
    return sorted(left_set), sorted(right_set)
|
|
|
|
|
|
|
|
|
2015-01-30 02:31:03 +03:00
|
|
|
def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
          gold_preproc=False, force_gold=False):
    """Train the POS tagger and dependency parser from scratch.

    Any existing ``deps`` and ``pos`` sub-directories of ``model_dir`` are
    deleted and recreated. The tagger model directory is initialised with
    ``setup_model_dir``, and the parser config (features, seed, labels) is
    written with ``Config.write`` before training starts.

    Args:
        Language: language class (e.g. ``English``). It must provide
            ``ParserTransitionSystem.get_labels`` and, when instantiated,
            ``tokenizer``, ``tagger`` and ``parser``.
        paragraphs: mutable list of gold-annotation items. Each item must
            expose ``raw``, ``tags`` and ``align_to_tokens(tokens)``, and is
            passed to ``nlp.parser.train``. The list is shuffled in place after
            every iteration.
        model_dir: directory in which ``deps`` and ``pos`` are (re)created.
        n_iter: number of passes over the data.
        feat_set: parser feature-set name written to the config.
        seed: random seed written to the parser config.
        gold_preproc: not used by this function; kept for interface
            compatibility.
        force_gold: forwarded to ``nlp.parser.train``.

    Returns:
        The last iteration's sum of ``nlp.parser.train`` return values divided
        by the token count (presumably head accuracy). Returns 0.0 if no
        iteration ran or no tokens were seen.
    """
    # The loop below works on gold-annotation objects; bind the parameter to
    # the name it uses (previously an undefined name -> NameError).
    gold_sents = paragraphs

    dep_model_dir = path.join(model_dir, 'deps')
    pos_model_dir = path.join(model_dir, 'pos')
    # Start from a clean slate: remove any previously trained models.
    if path.exists(dep_model_dir):
        shutil.rmtree(dep_model_dir)
    if path.exists(pos_model_dir):
        shutil.rmtree(pos_model_dir)
    os.mkdir(dep_model_dir)
    os.mkdir(pos_model_dir)
    setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES,
                    pos_model_dir)

    labels = Language.ParserTransitionSystem.get_labels(gold_sents)
    Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
                 labels=labels)

    nlp = Language()

    acc = 0.0  # defined even when n_iter == 0
    for itn in range(n_iter):
        heads_corr = 0
        pos_corr = 0
        n_tokens = 0
        for gold_sent in gold_sents:
            tokens = nlp.tokenizer(gold_sent.raw)
            gold_sent.align_to_tokens(tokens)
            nlp.tagger(tokens)
            heads_corr += nlp.parser.train(tokens, gold_sent,
                                           force_gold=force_gold)
            # Previously referenced an undefined ``gold_parse``.
            pos_corr += nlp.tagger.train(tokens, gold_sent.tags)
            n_tokens += len(tokens)
        acc = float(heads_corr) / n_tokens if n_tokens else 0.0
        pos_acc = float(pos_corr) / n_tokens if n_tokens else 0.0
        # Same output as the old `print '%d: ' % itn, ..., ...` statement.
        print('%d:  %.3f %.3f' % (itn, acc, pos_acc))
        # Shuffle the list actually iterated (previously shuffled a different
        # name, so the training order never changed).
        random.shuffle(gold_sents)
    nlp.parser.model.end_training()
    nlp.tagger.model.end_training()
    return acc
|
|
|
|
|
|
|
|
|
2015-01-30 02:31:03 +03:00
|
|
|
def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
2015-01-31 05:44:37 +03:00
|
|
|
global loss
|
2015-01-09 20:53:26 +03:00
|
|
|
nlp = Language()
|
2015-03-03 12:35:11 +03:00
|
|
|
uas_corr = 0
|
|
|
|
las_corr = 0
|
2015-02-02 15:02:48 +03:00
|
|
|
pos_corr = 0
|
|
|
|
n_tokens = 0
|
2015-01-09 20:53:26 +03:00
|
|
|
total = 0
|
2015-01-30 02:31:03 +03:00
|
|
|
skipped = 0
|
2015-01-31 05:44:37 +03:00
|
|
|
loss = 0
|
2015-01-09 20:53:26 +03:00
|
|
|
with codecs.open(dev_loc, 'r', 'utf8') as file_:
|
2015-02-18 06:02:09 +03:00
|
|
|
#paragraphs = read_tokenized_gold(file_)
|
2015-01-30 02:31:03 +03:00
|
|
|
paragraphs = read_docparse_gold(file_)
|
|
|
|
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
|
|
|
|
gold_preproc=gold_preproc):
|
|
|
|
assert len(tokens) == len(labels)
|
2015-01-09 20:53:26 +03:00
|
|
|
nlp.tagger(tokens)
|
2015-01-24 18:20:49 +03:00
|
|
|
nlp.parser(tokens)
|
2015-01-09 20:53:26 +03:00
|
|
|
for i, token in enumerate(tokens):
|
2015-02-18 06:02:09 +03:00
|
|
|
pos_corr += token.tag_ == tag_strs[i]
|
2015-02-02 15:02:48 +03:00
|
|
|
n_tokens += 1
|
2015-01-30 02:31:03 +03:00
|
|
|
if heads[i] is None:
|
|
|
|
skipped += 1
|
2015-01-30 08:36:24 +03:00
|
|
|
continue
|
|
|
|
if is_punct_label(labels[i]):
|
2015-01-09 20:53:26 +03:00
|
|
|
continue
|
2015-03-03 12:35:11 +03:00
|
|
|
uas_corr += token.head.i == heads[i]
|
|
|
|
las_corr += token.head.i == heads[i] and token.dep_ == labels[i]
|
|
|
|
#print token.orth_, token.head.orth_, token.dep_, labels[i]
|
2015-01-09 20:53:26 +03:00
|
|
|
total += 1
|
2015-01-30 08:36:24 +03:00
|
|
|
print loss, skipped, (loss+skipped + total)
|
2015-02-02 15:02:48 +03:00
|
|
|
print pos_corr / n_tokens
|
2015-03-03 12:35:11 +03:00
|
|
|
print float(las_corr) / (total + loss)
|
|
|
|
return float(uas_corr) / (total + loss)
|
2015-01-09 20:53:26 +03:00
|
|
|
|
|
|
|
|
|
|
|
def main(train_loc, dev_loc, model_dir):
    """Train on ``train_loc``, then print the evaluation score on ``dev_loc``.

    Invoked through ``plac.call(main)``, so the three parameters are taken
    from the command line.

    Args:
        train_loc: path to the training file in the docparse format.
        dev_loc: path to the development file. It is passed as a path because
            ``evaluate`` opens and parses it itself.
        model_dir: directory in which the models are written.
    """
    # read_docparse_gold calls file_.read(), so it needs an open file, not a
    # path string.
    with codecs.open(train_loc, 'r', 'utf8') as file_:
        train_data = read_docparse_gold(file_)
    # NOTE(review): train() iterates items exposing .raw / .tags /
    # .align_to_tokens(), while read_docparse_gold returns plain tuples;
    # confirm which reader train() is meant to consume.
    train(English, train_data, model_dir,
          gold_preproc=False, force_gold=False)
    # Previously passed already-parsed paragraphs where evaluate() expects a
    # file path (it calls codecs.open on it).
    print(evaluate(English, dev_loc, model_dir, gold_preproc=False))
|
2015-01-09 20:53:26 +03:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: plac builds the command-line interface from
    # main()'s signature (train_loc, dev_loc, model_dir) and calls it.
    plac.call(main)
|