* Work on sbd

This commit is contained in:
Matthew Honnibal 2015-01-29 03:18:29 +11:00
parent b08c0ce54e
commit f590382134
2 changed files with 75 additions and 27 deletions

View File

@ -52,13 +52,12 @@ cdef inline bint _can_break_shift(const State* s) nogil:
cdef int i
if not USE_BREAK:
return False
elif not _can_shift(s):
elif at_eol(s):
return False
else:
# P. 757
# In UPP, if Shift(F) or RightArc(F) fail to result in a single parsing
# tree, they cannot be performed as well.
seen_headless = False
for i in range(s.stack_len):
if s.sent[s.stack[i]].head == 0:
return False
@ -92,12 +91,18 @@ cdef int _shift_cost(const State* s, const int* gold) except -1:
cost += children_in_stack(s, s.i, gold)
if NON_MONOTONIC:
cost += gold[s.stack[0]] == s.i
# If we can break, and there's no cost to doing so, we should
if _can_break_shift(s) and _break_shift_cost(s, gold) == 0:
cost += 1
return cost
cdef int _right_cost(const State* s, const int* gold) except -1:
assert s.stack_len >= 1
cost = 0
# If we can break, and there's no cost to doing so, we should
if _can_break_right(s) and _break_right_cost(s, gold) == 0:
cost += 1
if gold[s.i] == s.stack[0]:
return cost
cost += head_in_buffer(s, s.i, gold)
@ -130,24 +135,48 @@ cdef int _reduce_cost(const State* s, const int* gold) except -1:
cdef int _break_shift_cost(const State* s, const int* gold) except -1:
cdef int cost = _shift_cost(s, gold)
# When we break, we Reduce all of the words on the stack. So, the Break
# cost is the sum of the Reduce costs
for i in range(s.stack_len):
cost += children_in_buffer(s, s.stack[i], gold)
if NON_MONOTONIC:
cost += head_in_buffer(s, s.stack[i], gold)
# When we break, we Reduce all of the words on the stack. We also remove
# the first word from the buffer.
#
# n0_cost:
cdef int cost = 0
# number of head/child deps between n0 and N1...Nn
cost += children_in_buffer(s, s.i, gold)
cost += head_in_buffer(s, s.i, gold)
# Don't count self-deps
if gold[s.i] == s.i:
cost -= 2
# number of child deps from N0 into stack
cost += children_in_stack(s, s.i, gold)
# number of head deps to N0 from stack
cost += head_in_stack(s, s.i, gold)
# Number of deps between S0...Sn and N1...Nn
for i in range(s.i+1, s.sent_len):
cost += children_in_stack(s, i, gold)
cost += head_in_stack(s, i, gold)
return cost
cdef int _break_right_cost(const State* s, const int* gold) except -1:
cdef int cost = _right_cost(s, gold)
# When we break, we Reduce all of the words on the stack. So, the Break
# cost is the sum of the Reduce costs
for i in range(s.stack_len):
cost += children_in_buffer(s, s.stack[i], gold)
if NON_MONOTONIC:
cost += head_in_buffer(s, s.stack[i], gold)
cdef int cost = 0
assert s.stack_len >= 1
cdef int i
# When we break, we Reduce all of the words on the stack. We also remove
# the first word from the buffer.
#
# n0_cost:
# number of head/child deps between n0 and N0...Nn
cost += children_in_buffer(s, s.i, gold)
cost += head_in_buffer(s, s.i, gold)
# number of child deps from N0 into stack
cost += children_in_stack(s, s.i, gold)
# number of head deps to N0 from S1..Sn
for i in range(1, s.stack_len):
cost += s.stack[-i] == gold[s.i]
# Number of deps between S0...Sn and N1...Nn
for i in range(s.i+1, s.sent_len):
cost += children_in_stack(s, i, gold)
cost += head_in_stack(s, i, gold)
return cost
@ -213,14 +242,14 @@ cdef class TransitionSystem:
add_dep(s, s.stack[0], s.i, t.label)
push_stack(s)
elif t.move == REDUCE:
# TODO: Huh? Is this some weirdness from the non-monotonic?
add_dep(s, s.stack[-1], s.stack[0], get_s0(s).dep)
pop_stack(s)
elif t.move == BREAK_RIGHT:
add_dep(s, s.stack[0], s.i, t.label)
push_stack(s)
while s.stack_len != 0:
if not has_head(get_s0(s)):
get_s0(s).dep = 0
#add_dep(s, s.stack[-1], s.stack[0], get_s0(s).dep)
s.stack -= 1
s.stack_len -= 1
if not at_eol(s):
@ -228,8 +257,9 @@ cdef class TransitionSystem:
elif t.move == BREAK_SHIFT:
push_stack(s)
get_s0(s).dep = 0
s.stack -= s.stack_len
s.stack_len = 0
while s.stack_len != 0:
s.stack -= 1
s.stack_len -= 1
if not at_eol(s):
push_stack(s)
else:
@ -289,10 +319,11 @@ cdef class TransitionSystem:
elif gold_heads[s.i] == s.stack[0]:
target_label = gold_labels[s.i]
if guess.move == RIGHT or guess.move == BREAK_RIGHT:
guess.cost += guess.label != target_label
if unl_costs[guess.move] != 0:
guess.cost += guess.label != target_label
for i in range(self.n_moves):
t = self._moves[i]
if (t.move == RIGHT or t.move == BREAK_RIGHT) and t.label == target_label:
if t.label == target_label and unl_costs[t.move] == 0:
return t
cdef int best = -1

View File

@ -41,11 +41,12 @@ def set_debug(val):
cdef unicode print_state(State* s, list words):
words = list(words) + ['EOL']
top = words[s.stack[0]]
second = words[s.stack[-1]]
top = words[s.stack[0]] + '_%d' % s.sent[s.stack[0]].head
second = words[s.stack[-1]] + '_%d' % s.sent[s.stack[-1]].head
third = words[s.stack[-2]] + '_%d' % s.sent[s.stack[-2]].head
n0 = words[s.i]
n1 = words[s.i + 1]
return ' '.join((second, top, '|', n0, n1))
return ' '.join((str(s.stack_len), third, second, top, '|', n0, n1))
def get_templates(name):
@ -86,7 +87,8 @@ cdef class GreedyParser:
tokens.is_parsed = True
return 0
def train_sent(self, Tokens tokens, list gold_heads, list gold_labels):
def train_sent(self, Tokens tokens, list gold_heads, list gold_labels,
force_gold=False):
cdef:
const Feature* feats
const weight_t* scores
@ -104,15 +106,30 @@ cdef class GreedyParser:
labels_array[i] = self.moves.label_ids[gold_labels[i]]
py_words = [t.orth_ for t in tokens]
py_moves = ['S', 'D', 'L', 'R', 'BS', 'BR']
history = []
#print py_words
cdef State* state = init_state(mem, tokens.data, tokens.length)
while not is_final(state):
fill_context(context, state)
scores = self.model.score(context)
guess = self.moves.best_valid(scores, state)
best = self.moves.best_gold(&guess, scores, state, heads_array, labels_array)
history.append((py_moves[best.move], print_state(state, py_words)))
self.model.update(context, guess.clas, best.clas, guess.cost)
self.moves.transition(state, &guess)
if force_gold:
self.moves.transition(state, &best)
else:
self.moves.transition(state, &guess)
cdef int n_corr = 0
for i in range(tokens.length):
n_corr += (i + state.sent[i].head) == gold_heads[i]
if force_gold and n_corr != tokens.length:
print py_words
print gold_heads
for move, state_str in history:
print move, state_str
for i in range(tokens.length):
print py_words[i], py_words[i + state.sent[i].head], py_words[gold_heads[i]]
raise Exception
return n_corr