Merge branch 'punctparse'

This commit is contained in:
Matthew Honnibal 2015-01-30 16:38:56 +11:00
commit b3f9b199cf
6 changed files with 226 additions and 58 deletions

View File

@ -9,6 +9,7 @@ import codecs
import random
import time
import gzip
import nltk
import plac
import cProfile
@ -22,10 +23,15 @@ from spacy.syntax.parser import GreedyParser
from spacy.syntax.util import Config
def is_punct_label(label):
return label == 'P' or label.lower() == 'punct'
def read_tokenized_gold(file_):
"""Read a standard CoNLL/MALT-style format"""
sents = []
for sent_str in file_.read().strip().split('\n\n'):
ids = []
words = []
heads = []
labels = []
@ -35,50 +41,137 @@ def read_tokenized_gold(file_):
words.append(word)
if head_idx == -1:
head_idx = i
ids.append(id_)
heads.append(head_idx)
labels.append(label)
tags.append(pos_string)
sents.append((words, heads, labels, tags))
sents.append((ids_, words, heads, labels, tags))
return sents
def read_docparse_gold(file_):
sents = []
paragraphs = []
for sent_str in file_.read().strip().split('\n\n'):
words = []
heads = []
labels = []
tags = []
ids = []
lines = sent_str.strip().split('\n')
raw_text = lines[0]
tok_text = lines[1]
for i, line in enumerate(lines[2:]):
word, pos_string, head_idx, label = _parse_line(line)
id_, word, pos_string, head_idx, label = _parse_line(line)
if label == 'root':
label = 'ROOT'
words.append(word)
if head_idx == -1:
head_idx = i
if head_idx < 0:
head_idx = id_
ids.append(id_)
heads.append(head_idx)
labels.append(label)
tags.append(pos_string)
words = tok_text.replace('<SEP>', ' ').replace('<SENT>', ' ').split(' ')
sents.append((words, heads, labels, tags))
return sents
tokenized = [sent_str.replace('<SEP>', ' ').split(' ')
for sent_str in tok_text.split('<SENT>')]
paragraphs.append((raw_text, tokenized, ids, words, tags, heads, labels))
return paragraphs
def _map_indices_to_tokens(ids, heads):
mapped = []
for head in heads:
if head not in ids:
mapped.append(None)
else:
mapped.append(ids.index(head))
return mapped
def _parse_line(line):
pieces = line.split()
if len(pieces) == 4:
return pieces[0], pieces[1], int(pieces[2]) - 1, pieces[3]
return 0, pieces[0], pieces[1], int(pieces[2]) - 1, pieces[3]
else:
id_ = int(pieces[0])
word = pieces[1]
pos = pieces[3]
head_idx = int(pieces[6]) - 1
head_idx = int(pieces[6])
label = pieces[7]
return word, pos, head_idx, label
return id_, word, pos, head_idx, label
loss = 0
def _align_annotations_to_non_gold_tokens(tokens, words, annot):
global loss
tags = []
heads = []
labels = []
orig_words = list(words)
missed = []
for token in tokens:
while annot and token.idx > annot[0][0]:
miss_id, miss_tag, miss_head, miss_label = annot.pop(0)
miss_w = words.pop(0)
if not is_punct_label(miss_label):
missed.append(miss_w)
loss += 1
if not annot:
tags.append(None)
heads.append(None)
labels.append(None)
continue
id_, tag, head, label = annot[0]
if token.idx == id_:
tags.append(tag)
heads.append(head)
labels.append(label)
annot.pop(0)
words.pop(0)
elif token.idx < id_:
tags.append(None)
heads.append(None)
labels.append(None)
else:
raise StandardError
#if missed:
# print orig_words
# print missed
# for t in tokens:
# print t.idx, t.orth_
return loss, tags, heads, labels
def iter_data(paragraphs, tokenizer, gold_preproc=False):
for raw, tokenized, ids, words, tags, heads, labels in paragraphs:
if not gold_preproc:
tokens = tokenizer(raw)
loss, tags, heads, labels = _align_annotations_to_non_gold_tokens(
tokens, list(words),
zip(ids, tags, heads, labels))
ids = [t.idx for t in tokens]
heads = _map_indices_to_tokens(ids, heads)
yield tokens, tags, heads, labels
else:
assert len(words) == len(heads)
for words in tokenized:
sent_ids = ids[:len(words)]
sent_tags = tags[:len(words)]
sent_heads = heads[:len(words)]
sent_labels = labels[:len(words)]
sent_heads = _map_indices_to_tokens(sent_ids, sent_heads)
tokens = tokenizer.tokens_from_list(words)
yield tokens, sent_tags, sent_heads, sent_labels
ids = ids[len(words):]
tags = tags[len(words):]
heads = heads[len(words):]
labels = labels[len(words):]
def get_labels(sents):
left_labels = set()
right_labels = set()
for _, heads, labels, _ in sents:
for raw, tokenized, ids, words, tags, heads, labels in sents:
for child, (head, label) in enumerate(zip(heads, labels)):
if head > child:
left_labels.add(label)
@ -87,7 +180,8 @@ def get_labels(sents):
return list(sorted(left_labels)), list(sorted(right_labels))
def train(Language, sents, model_dir, n_iter=15, feat_set=u'basic', seed=0):
def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
gold_preproc=False):
dep_model_dir = path.join(model_dir, 'deps')
pos_model_dir = path.join(model_dir, 'pos')
if path.exists(dep_model_dir):
@ -99,7 +193,7 @@ def train(Language, sents, model_dir, n_iter=15, feat_set=u'basic', seed=0):
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES,
pos_model_dir)
left_labels, right_labels = get_labels(sents)
left_labels, right_labels = get_labels(paragraphs)
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
left_labels=left_labels, right_labels=right_labels)
@ -109,58 +203,50 @@ def train(Language, sents, model_dir, n_iter=15, feat_set=u'basic', seed=0):
heads_corr = 0
pos_corr = 0
n_tokens = 0
for words, heads, labels, tags in sents:
tags = [nlp.tagger.tag_names.index(tag) for tag in tags]
tokens = nlp.tokenizer.tokens_from_list(words)
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
gold_preproc=gold_preproc):
nlp.tagger(tokens)
heads_corr += nlp.parser.train_sent(tokens, heads, labels)
pos_corr += nlp.tagger.train(tokens, tags)
heads_corr += nlp.parser.train_sent(tokens, heads, labels, force_gold=False)
pos_corr += nlp.tagger.train(tokens, tag_strs)
n_tokens += len(tokens)
acc = float(heads_corr) / n_tokens
pos_acc = float(pos_corr) / n_tokens
print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc
random.shuffle(sents)
random.shuffle(paragraphs)
nlp.parser.model.end_training()
nlp.tagger.model.end_training()
#nlp.parser.model.dump(path.join(dep_model_dir, 'model'), freq_thresh=0)
return acc
def evaluate(Language, dev_loc, model_dir):
def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
nlp = Language()
n_corr = 0
total = 0
skipped = 0
with codecs.open(dev_loc, 'r', 'utf8') as file_:
sents = read_tokenized_gold(file_)
for words, heads, labels, tags in sents:
tokens = nlp.tokenizer.tokens_from_list(words)
paragraphs = read_docparse_gold(file_)
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
gold_preproc=gold_preproc):
assert len(tokens) == len(labels)
nlp.tagger(tokens)
nlp.parser(tokens)
for i, token in enumerate(tokens):
#print i, token.string, i + token.head, heads[i], labels[i]
if labels[i] == 'P' or labels[i] == 'punct':
if heads[i] is None:
skipped += 1
continue
if is_punct_label(labels[i]):
continue
n_corr += token.head.i == heads[i]
total += 1
return float(n_corr) / total
PROFILE = False
print loss, skipped, (loss+skipped + total)
return float(n_corr) / (total + loss)
def main(train_loc, dev_loc, model_dir):
with codecs.open(train_loc, 'r', 'utf8') as file_:
train_sents = read_tokenized_gold(file_)
if PROFILE:
import cProfile
import pstats
cmd = "train(EN, train_sents, tag_names, model_dir, n_iter=2)"
cProfile.runctx(cmd, globals(), locals(), "Profile.prof")
s = pstats.Stats("Profile.prof")
s.strip_dirs().sort_stats("time").print_stats()
else:
train(English, train_sents, model_dir)
print evaluate(English, dev_loc, model_dir)
train_sents = read_docparse_gold(file_)
#train(English, train_sents, model_dir, gold_preproc=False)
print evaluate(English, dev_loc, model_dir, gold_preproc=False)
if __name__ == '__main__':

View File

@ -255,19 +255,23 @@ cdef class EnPosTagger:
tokens._tag_strings = self.tag_names
tokens.is_tagged = True
def train(self, Tokens tokens, object golds):
def train(self, Tokens tokens, object gold_tag_strs):
cdef int i
cdef int loss
cdef atom_t[N_CONTEXT_FIELDS] context
cdef const weight_t* scores
golds = [self.tag_names.index(g) if g is not None else -1
for g in gold_tag_strs]
correct = 0
for i in range(tokens.length):
fill_context(context, i, tokens.data)
scores = self.model.score(context)
guess = arg_max(scores, self.model.n_classes)
self.model.update(context, guess, golds[i], guess != golds[i])
loss = guess != golds[i] if golds[i] != -1 else 0
self.model.update(context, guess, golds[i], loss)
tokens.data[i].tag = guess
self.set_morph(i, tokens.data)
correct += guess == golds[i]
correct += loss == 0
return correct
cdef int set_morph(self, const int i, TokenC* tokens) except -1:

View File

@ -8,7 +8,9 @@ from ._state cimport head_in_stack, children_in_stack
from ..structs cimport TokenC
DEF NON_MONOTONIC = True
DEF USE_BREAK = True
cdef enum:
@ -16,8 +18,12 @@ cdef enum:
REDUCE
LEFT
RIGHT
BREAK
N_MOVES
# Break transition from here
# http://www.aclweb.org/anthology/P13-1074
cdef inline bint _can_shift(const State* s) nogil:
return not at_eol(s)
@ -41,6 +47,24 @@ cdef inline bint _can_reduce(const State* s) nogil:
return s.stack_len >= 2 and has_head(get_s0(s))
cdef inline bint _can_break(const State* s) nogil:
cdef int i
if not USE_BREAK:
return False
elif at_eol(s):
return False
else:
# If stack is disconnected, cannot break
seen_headless = False
for i in range(s.stack_len):
if s.sent[s.stack[-i]].head == 0:
if seen_headless:
return False
else:
seen_headless = True
return True
cdef int _shift_cost(const State* s, const int* gold) except -1:
assert not at_eol(s)
cost = 0
@ -48,6 +72,9 @@ cdef int _shift_cost(const State* s, const int* gold) except -1:
cost += children_in_stack(s, s.i, gold)
if NON_MONOTONIC:
cost += gold[s.stack[0]] == s.i
# If we can break, and there's no cost to doing so, we should
if _can_break(s) and _break_cost(s, gold) == 0:
cost += 1
return cost
@ -74,6 +101,7 @@ cdef int _left_cost(const State* s, const int* gold) except -1:
cost += children_in_buffer(s, s.stack[0], gold)
if NON_MONOTONIC and s.stack_len >= 2:
cost += gold[s.stack[0]] == s.stack[-1]
cost += gold[s.stack[0]] == s.stack[0]
return cost
@ -85,6 +113,16 @@ cdef int _reduce_cost(const State* s, const int* gold) except -1:
return cost
cdef int _break_cost(const State* s, const int* gold) except -1:
# When we break, we Reduce all of the words on the stack.
cdef int cost = 0
# Number of deps between S0...Sn and N0...Nn
for i in range(s.i, s.sent_len):
cost += children_in_stack(s, i, gold)
cost += head_in_stack(s, i, gold)
return cost
cdef class TransitionSystem:
def __init__(self, list left_labels, list right_labels):
self.mem = Pool()
@ -94,7 +132,7 @@ cdef class TransitionSystem:
right_labels.pop(right_labels.index('ROOT'))
if 'ROOT' in left_labels:
left_labels.pop(left_labels.index('ROOT'))
self.n_moves = 2 + len(left_labels) + len(right_labels)
self.n_moves = 3 + len(left_labels) + len(right_labels)
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
cdef int i = 0
moves[i].move = SHIFT
@ -121,6 +159,10 @@ cdef class TransitionSystem:
moves[i].label = label_id
moves[i].clas = i
i += 1
moves[i].move = BREAK
moves[i].label = 0
moves[i].clas = i
i += 1
self._moves = moves
cdef int transition(self, State *s, const Transition* t) except -1:
@ -136,8 +178,17 @@ cdef class TransitionSystem:
add_dep(s, s.stack[0], s.i, t.label)
push_stack(s)
elif t.move == REDUCE:
# TODO: Huh? Is this some weirdness from the non-monotonic?
add_dep(s, s.stack[-1], s.stack[0], get_s0(s).dep)
pop_stack(s)
elif t.move == BREAK:
while s.stack_len != 0:
if get_s0(s).head == 0:
get_s0(s).dep = 0
s.stack -= 1
s.stack_len -= 1
if not at_eol(s):
push_stack(s)
else:
raise Exception(t.move)
@ -147,6 +198,7 @@ cdef class TransitionSystem:
valid[LEFT] = _can_left(s)
valid[RIGHT] = _can_right(s)
valid[REDUCE] = _can_reduce(s)
valid[BREAK] = _can_break(s)
cdef int best = -1
cdef weight_t score = 0
@ -175,6 +227,7 @@ cdef class TransitionSystem:
unl_costs[LEFT] = _left_cost(s, gold_heads) if _can_left(s) else -1
unl_costs[RIGHT] = _right_cost(s, gold_heads) if _can_right(s) else -1
unl_costs[REDUCE] = _reduce_cost(s, gold_heads) if _can_reduce(s) else -1
unl_costs[BREAK] = _break_cost(s, gold_heads) if _can_break(s) else -1
guess.cost = unl_costs[guess.move]
cdef Transition t
@ -191,10 +244,11 @@ cdef class TransitionSystem:
elif gold_heads[s.i] == s.stack[0]:
target_label = gold_labels[s.i]
if guess.move == RIGHT:
guess.cost += guess.label != target_label
if unl_costs[guess.move] != 0:
guess.cost += guess.label != target_label
for i in range(self.n_moves):
t = self._moves[i]
if t.move == RIGHT and t.label == target_label:
if t.label == target_label and unl_costs[t.move] == 0:
return t
cdef int best = -1

View File

@ -41,11 +41,12 @@ def set_debug(val):
cdef unicode print_state(State* s, list words):
words = list(words) + ['EOL']
top = words[s.stack[0]]
second = words[s.stack[-1]]
top = words[s.stack[0]] + '_%d' % s.sent[s.stack[0]].head
second = words[s.stack[-1]] + '_%d' % s.sent[s.stack[-1]].head
third = words[s.stack[-2]] + '_%d' % s.sent[s.stack[-2]].head
n0 = words[s.i]
n1 = words[s.i + 1]
return ' '.join((second, top, '|', n0, n1))
return ' '.join((str(s.stack_len), third, second, top, '|', n0, n1))
def get_templates(name):
@ -86,7 +87,8 @@ cdef class GreedyParser:
tokens.is_parsed = True
return 0
def train_sent(self, Tokens tokens, list gold_heads, list gold_labels):
def train_sent(self, Tokens tokens, list gold_heads, list gold_labels,
force_gold=False):
cdef:
const Feature* feats
const weight_t* scores
@ -100,19 +102,39 @@ cdef class GreedyParser:
cdef int* labels_array = <int*>mem.alloc(tokens.length, sizeof(int))
cdef int i
for i in range(tokens.length):
heads_array[i] = gold_heads[i]
labels_array[i] = self.moves.label_ids[gold_labels[i]]
if gold_heads[i] is None:
heads_array[i] = -1
labels_array[i] = -1
else:
heads_array[i] = gold_heads[i]
labels_array[i] = self.moves.label_ids[gold_labels[i]]
py_words = [t.orth_ for t in tokens]
py_moves = ['S', 'D', 'L', 'R', 'BS', 'BR']
history = []
#print py_words
cdef State* state = init_state(mem, tokens.data, tokens.length)
while not is_final(state):
fill_context(context, state)
scores = self.model.score(context)
guess = self.moves.best_valid(scores, state)
best = self.moves.best_gold(&guess, scores, state, heads_array, labels_array)
history.append((py_moves[best.move], print_state(state, py_words)))
self.model.update(context, guess.clas, best.clas, guess.cost)
self.moves.transition(state, &guess)
if force_gold:
self.moves.transition(state, &best)
else:
self.moves.transition(state, &guess)
cdef int n_corr = 0
for i in range(tokens.length):
if gold_heads[i] != -1:
n_corr += (i + state.sent[i].head) == gold_heads[i]
if force_gold and n_corr != tokens.length:
print py_words
print gold_heads
for move, state_str in history:
print move, state_str
for i in range(tokens.length):
print py_words[i], py_words[i + state.sent[i].head], py_words[gold_heads[i]]
raise Exception
return n_corr

View File

@ -95,7 +95,6 @@ cdef class Tokenizer:
return tokens
cdef int _try_cache(self, int idx, hash_t key, Tokens tokens) except -1:
#cached = <Cached*>self._specials.get(key)
cached = <_Cached*>self._cache.get(key)
if cached == NULL:
return False
@ -176,7 +175,10 @@ cdef class Tokenizer:
if string.n != 0:
cache_hit = self._try_cache(idx, string.key, tokens)
if cache_hit:
idx = tokens.data[tokens.length - 1].idx + 1
# Get last idx
idx = tokens.data[tokens.length - 1].idx
# Increment by last length
idx += tokens.data[tokens.length - 1].lex.length
else:
split = self._find_infix(string.chars, string.n)
if split == 0 or split == -1:

View File

@ -8,7 +8,7 @@ from spacy.en import English
@pytest.fixture
def EN():
return English()
return English().tokenizer
def test_single_word(EN):
tokens = EN(u'hello')