* Refactor arc_eager to use new TransitionSystem base class. Need to fix oracle

This commit is contained in:
Matthew Honnibal 2015-02-21 11:06:37 -05:00
parent b063001596
commit 8eadb984cb
4 changed files with 158 additions and 269 deletions

View File

@ -4,25 +4,8 @@ from thinc.typedefs cimport weight_t
from ._state cimport State from ._state cimport State
from .transition_system cimport TransitionSystem, Transition
cdef struct Transition: cdef class ArcEager(TransitionSystem):
int clas pass
int move
int label
int cost
weight_t score
cdef class TransitionSystem:
cdef Pool mem
cdef readonly int n_moves
cdef dict label_ids
cdef const Transition* _moves
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *
cdef Transition best_gold(self, Transition* guess, const weight_t* scores,
const State* s,
const int* gold_heads, const int* gold_labels) except *
cdef int transition(self, State *s, const Transition* t) except -1

View File

@ -8,11 +8,17 @@ from ._state cimport head_in_stack, children_in_stack
from ..structs cimport TokenC from ..structs cimport TokenC
from .transition_system cimport do_func_t, get_cost_func_t
from .conll cimport GoldParse
DEF NON_MONOTONIC = True DEF NON_MONOTONIC = True
DEF USE_BREAK = True DEF USE_BREAK = True
cdef weight_t MIN_SCORE = -90000
# Break transition from here
# http://www.aclweb.org/anthology/P13-1074
cdef enum: cdef enum:
SHIFT SHIFT
REDUCE REDUCE
@ -21,8 +27,150 @@ cdef enum:
BREAK BREAK
N_MOVES N_MOVES
# Break transition from here
# http://www.aclweb.org/anthology/P13-1074 cdef do_func_t[N_MOVES] do_funcs
cdef get_cost_func_t[N_MOVES] get_cost_funcs
cdef class ArcEager(TransitionSystem):
cdef Transition init_transition(self, int clas, int move, int label) except *:
return Transition(
score=0,
clas=i,
move=move,
label=label,
do=do_funcs[move],
get_cost=get_cost_funcs[move])
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = _can_shift(s)
is_valid[REDUCE] = _can_reduce(s)
is_valid[LEFT] = _can_left(s)
is_valid[RIGHT] = _can_right(s)
is_valid[BREAK] = _can_break(s)
cdef Transition best
cdef weight_t score = MIN_SCORE
cdef int i
for i in range(self.n_moves):
if scores[i] > score and is_valid[self.c[i].move]:
best = self.c[i]
score = scores[i]
# Label Shift moves with the best Right-Arc label, for non-monotonic
# actions
if best.move == SHIFT:
score = MIN_SCORE
for i in range(self.n_moves):
if self.c[i].move == RIGHT and scores[i] > score:
best.label = self.c[i].label
score = scores[i]
return best
cdef int _do_shift(const Transition* self, State* state) except -1:
# Set the dep label, in case we need it after we reduce
if NON_MONOTONIC:
get_s0(state).dep = self.label
push_stack(state)
cdef int _do_left(const Transition* self, State* state) except -1:
add_dep(state, state.i, state.stack[0], self.label)
pop_stack(state)
cdef int _do_right(const Transition* self, State* state) except -1:
add_dep(state, state.stack[0], state.i, self.label)
push_stack(state)
cdef int _do_reduce(const Transition* self, State* state) except -1:
# TODO: Huh? Is this some weirdness from the non-monotonic?
add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep)
pop_stack(state)
cdef int _do_break(const Transition* self, State* state) except -1:
state.sent[state.i-1].sent_end = True
while state.stack_len != 0:
if get_s0(state).head == 0:
get_s0(state).dep = 0
state.stack -= 1
state.stack_len -= 1
if not at_eol(state):
push_stack(state)
do_funcs[SHIFT] = _do_shift
do_funcs[REDUCE] = _do_reduce
do_funcs[LEFT] = _do_left
do_funcs[RIGHT] = _do_right
do_funcs[BREAK] = _do_break
cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1:
assert not at_eol(s)
cost = 0
cost += head_in_stack(s, s.i, gold.heads)
cost += children_in_stack(s, s.i, gold.heads)
if NON_MONOTONIC:
cost += gold[s.stack[0]] == s.i
# If we can break, and there's no cost to doing so, we should
if _can_break(s) and _break_cost(self, s, gold) == 0:
cost += 1
return cost
cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) except -1:
assert s.stack_len >= 1
cost = 0
if gold[s.i] == s.stack[0]:
return cost
cost += head_in_buffer(s, s.i, gold.heads)
cost += children_in_stack(s, s.i, gold.heads)
cost += head_in_stack(s, s.i, gold.heads)
if NON_MONOTONIC:
cost += gold[s.stack[0]] == s.i
return cost
cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) except -1:
assert s.stack_len >= 1
cost = 0
if gold[s.stack[0]] == s.i:
return cost
cost += head_in_buffer(s, s.stack[0], gold.heads)
cost += children_in_buffer(s, s.stack[0], gold.heads)
if NON_MONOTONIC and s.stack_len >= 2:
cost += gold[s.stack[0]] == s.stack[-1]
cost += gold[s.stack[0]] == s.stack[0]
return cost
cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) except -1:
cdef int cost = 0
cost += children_in_buffer(s, s.stack[0], gold.heads)
if NON_MONOTONIC:
cost += head_in_buffer(s, s.stack[0], gold.heads)
return cost
cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) except -1:
# When we break, we Reduce all of the words on the stack.
cdef int cost = 0
# Number of deps between S0...Sn and N0...Nn
for i in range(s.i, s.sent_len):
cost += children_in_stack(s, i, gold.heads)
cost += head_in_stack(s, i, gold.heads)
return cost
get_cost_funcs[SHIFT] = _shift_cost
get_cost_funcs[REDUCE] = _reduce_cost
get_cost_funcs[LEFT] = _left_cost
get_cost_funcs[RIGHT] = _right_cost
get_cost_funcs[BREAK] = _break_cost
cdef inline bint _can_shift(const State* s) nogil: cdef inline bint _can_shift(const State* s) nogil:
@ -63,224 +211,3 @@ cdef inline bint _can_break(const State* s) nogil:
else: else:
seen_headless = True seen_headless = True
return True return True
cdef int _shift_cost(const State* s, const int* gold) except -1:
assert not at_eol(s)
cost = 0
cost += head_in_stack(s, s.i, gold)
cost += children_in_stack(s, s.i, gold)
if NON_MONOTONIC:
cost += gold[s.stack[0]] == s.i
# If we can break, and there's no cost to doing so, we should
if _can_break(s) and _break_cost(s, gold) == 0:
cost += 1
return cost
cdef int _right_cost(const State* s, const int* gold) except -1:
assert s.stack_len >= 1
cost = 0
if gold[s.i] == s.stack[0]:
return cost
cost += head_in_buffer(s, s.i, gold)
cost += children_in_stack(s, s.i, gold)
cost += head_in_stack(s, s.i, gold)
if NON_MONOTONIC:
cost += gold[s.stack[0]] == s.i
return cost
cdef int _left_cost(const State* s, const int* gold) except -1:
assert s.stack_len >= 1
cost = 0
if gold[s.stack[0]] == s.i:
return cost
cost += head_in_buffer(s, s.stack[0], gold)
cost += children_in_buffer(s, s.stack[0], gold)
if NON_MONOTONIC and s.stack_len >= 2:
cost += gold[s.stack[0]] == s.stack[-1]
cost += gold[s.stack[0]] == s.stack[0]
return cost
cdef int _reduce_cost(const State* s, const int* gold) except -1:
cdef int cost = 0
cost += children_in_buffer(s, s.stack[0], gold)
if NON_MONOTONIC:
cost += head_in_buffer(s, s.stack[0], gold)
return cost
cdef int _break_cost(const State* s, const int* gold) except -1:
# When we break, we Reduce all of the words on the stack.
cdef int cost = 0
# Number of deps between S0...Sn and N0...Nn
for i in range(s.i, s.sent_len):
cost += children_in_stack(s, i, gold)
cost += head_in_stack(s, i, gold)
return cost
cdef class ArcEager(TransitionSystem):
def __init__(self, list left_labels, list right_labels):
self.mem = Pool()
left_labels.sort()
right_labels.sort()
if 'ROOT' in right_labels:
right_labels.pop(right_labels.index('ROOT'))
if 'ROOT' in left_labels:
left_labels.pop(left_labels.index('ROOT'))
self.n_moves = 3 + len(left_labels) + len(right_labels)
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
cdef int i = 0
moves[i].move = SHIFT
moves[i].label = 0
moves[i].clas = i
i += 1
moves[i].move = REDUCE
moves[i].label = 0
moves[i].clas = i
i += 1
self.label_ids = {'ROOT': 0}
cdef int label_id
for label_str in left_labels:
label_str = unicode(label_str)
label_id = self.label_ids.setdefault(label_str, len(self.label_ids))
moves[i].move = LEFT
moves[i].label = label_id
moves[i].clas = i
i += 1
for label_str in right_labels:
label_str = unicode(label_str)
label_id = self.label_ids.setdefault(label_str, len(self.label_ids))
moves[i].move = RIGHT
moves[i].label = label_id
moves[i].clas = i
i += 1
moves[i].move = BREAK
moves[i].label = 0
moves[i].clas = i
i += 1
self.c = moves
cdef int transition(self, State *s, const Transition* t) except -1:
if t.move == SHIFT:
# Set the dep label, in case we need it after we reduce
if NON_MONOTONIC:
s.sent[s.i].dep = t.label
push_stack(s)
elif t.move == LEFT:
add_dep(s, s.i, s.stack[0], t.label)
pop_stack(s)
elif t.move == RIGHT:
add_dep(s, s.stack[0], s.i, t.label)
push_stack(s)
elif t.move == REDUCE:
# TODO: Huh? Is this some weirdness from the non-monotonic?
add_dep(s, s.stack[-1], s.stack[0], get_s0(s).dep)
pop_stack(s)
elif t.move == BREAK:
s.sent[s.i-1].sent_end = True
while s.stack_len != 0:
if get_s0(s).head == 0:
get_s0(s).dep = 0
s.stack -= 1
s.stack_len -= 1
if not at_eol(s):
push_stack(s)
else:
raise Exception(t.move)
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
cdef bint[N_MOVES] valid
valid[SHIFT] = _can_shift(s)
valid[LEFT] = _can_left(s)
valid[RIGHT] = _can_right(s)
valid[REDUCE] = _can_reduce(s)
valid[BREAK] = _can_break(s)
cdef int best = -1
cdef weight_t score = 0
cdef weight_t best_r_score = -9000
cdef int best_r_label = -1
cdef int i
for i in range(self.n_moves):
if valid[self._moves[i].move] and (best == -1 or scores[i] > score):
best = i
score = scores[i]
if self._moves[i].move == RIGHT and scores[i] > best_r_score:
best_r_label = self._moves[i].label
assert best >= 0
cdef Transition t = self._moves[best]
t.score = score
if t.move == SHIFT:
t.label = best_r_label
return t
cdef Transition best_gold(self, Transition* guess, const weight_t* scores,
const State* s,
const int* gold_heads, const int* gold_labels) except *:
# If we can create a gold dependency, only one action can be correct
cdef int[N_MOVES] unl_costs
unl_costs[SHIFT] = _shift_cost(s, gold_heads) if _can_shift(s) else -1
unl_costs[LEFT] = _left_cost(s, gold_heads) if _can_left(s) else -1
unl_costs[RIGHT] = _right_cost(s, gold_heads) if _can_right(s) else -1
unl_costs[REDUCE] = _reduce_cost(s, gold_heads) if _can_reduce(s) else -1
unl_costs[BREAK] = _break_cost(s, gold_heads) if _can_break(s) else -1
guess.cost = unl_costs[guess.move]
cdef Transition t
cdef int target_label
cdef int i
if gold_heads[s.stack[0]] == s.i:
target_label = gold_labels[s.stack[0]]
if guess.move == LEFT:
guess.cost += guess.label != target_label
for i in range(self.n_moves):
t = self._moves[i]
if t.move == LEFT and t.label == target_label:
return t
elif gold_heads[s.i] == s.stack[0]:
target_label = gold_labels[s.i]
if guess.move == RIGHT:
if unl_costs[guess.move] != 0:
guess.cost += guess.label != target_label
for i in range(self.n_moves):
t = self._moves[i]
if t.label == target_label and unl_costs[t.move] == 0:
return t
cdef int best = -1
cdef weight_t score = -9000
for i in range(self.n_moves):
t = self._moves[i]
if unl_costs[t.move] == 0 and (best == -1 or scores[i] > score):
best = i
score = scores[i]
t = self._moves[best]
t.score = score
if best < 0:
msg = ("No gold move found for configuration.\n"
"Is the gold-standard parse a projective tree?\n"
"S unl cost: %d\n"
"D unl cost: %d\n"
"L unl cost: %d\n"
"R unl cost: %d\n"
"S0, S0 gold: %d, %d\n"
"N0, N0 gold: %d, %d\n"
)
fields = [unl_costs[SHIFT], unl_costs[REDUCE], unl_costs[LEFT],
unl_costs[RIGHT],
s.stack[0], gold_heads[s.stack[0]],
s.i, gold_heads[s.i]]
raise OracleError(msg % tuple(fields))
return t
class OracleError(Exception):
pass

View File

@ -3,7 +3,7 @@ from thinc.typedefs cimport weight_t
from ..structs cimport TokenC from ..structs cimport TokenC
from ._state cimport State from ._state cimport State
from .conll cimport GoldParse
cdef struct Transition: cdef struct Transition:
@ -12,19 +12,13 @@ cdef struct Transition:
int label int label
weight_t score weight_t score
int cost
int (*get_cost)(const Transition* self, const State* state, const TokenC* gold) except -1
int (*is_valid)(const Transition* self, const State* state) except -1
int (*get_cost)(const Transition* self, const State* state, GoldParse gold) except -1
int (*do)(const Transition* self, State* state) except -1 int (*do)(const Transition* self, State* state) except -1
ctypedef int (*get_cost_func_t)(const Transition* self, const State* state, ctypedef int (*get_cost_func_t)(const Transition* self, const State* state,
const TokenC* gold) except -1 GoldParse gold) except -1
ctypedef int (*is_valid_func_t)(const Transition* self, const State* state) except -1
ctypedef int (*do_func_t)(const Transition* self, State* state) except -1 ctypedef int (*do_func_t)(const Transition* self, State* state) except -1
@ -39,4 +33,4 @@ cdef class TransitionSystem:
cdef const Transition best_valid(self, const weight_t*, const State*) except * cdef const Transition best_valid(self, const weight_t*, const State*) except *
cdef const Transition best_gold(self, const weight_t*, const State*, cdef const Transition best_gold(self, const weight_t*, const State*,
const TokenC*) except * GoldParse gold) except *

View File

@ -25,22 +25,7 @@ cdef class TransitionSystem:
raise NotImplementedError raise NotImplementedError
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
cdef Transition best raise NotImplementedError
cdef weight_t score = MIN_SCORE
cdef int i
for i in range(self.n_moves):
if scores[i] > score and self.c[i].is_valid(&self.c[i], s):
best = self.c[i]
score = scores[i]
# Label Shift moves with the best Right-Arc label, for non-monotonic
# actions
#if best.move == SHIFT:
# score = MIN_SCORE
# for i in range(self.n_moves):
# if self.c[i].move == RIGHT and scores[i] > score:
# best.label = self.c[i].label
# score = scores[i]
return best
cdef Transition best_gold(self, const weight_t* scores, const State* s, cdef Transition best_gold(self, const weight_t* scores, const State* s,
const TokenC* gold) except *: const TokenC* gold) except *: