Work on parser oracle

Update arc_eager oracle

Restore ArcEager.get_cost function

Update transition system
This commit is contained in:
Matthew Honnibal 2020-06-21 01:01:09 +02:00
parent 75a5f2d499
commit 5ca4c19ef2
2 changed files with 101 additions and 54 deletions

View File

@ -82,7 +82,7 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
gold_i = cand_to_gold[cand_i] gold_i = cand_to_gold[cand_i]
if gold_i is not None: # Alignment found if gold_i is not None: # Alignment found
ref_tok = example.y.c[gold_i] ref_tok = example.y.c[gold_i]
gold_head = gold_to_cand[ref_tok.head + gold_i] gold_head = gold_to_cand[gold_i + ref_tok.head]
if gold_head is not None: if gold_head is not None:
gs.heads[cand_i] = gold_head gs.heads[cand_i] = gold_head
gs.labels[cand_i] = ref_tok.dep gs.labels[cand_i] = ref_tok.dep
@ -106,17 +106,17 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
stack_words = set() stack_words = set()
for i in range(stcls.stack_depth()): for i in range(stcls.stack_depth()):
s_i = stcls.S(i) s_i = stcls.S(i)
head = s_i + gs.heads[s_i] head = gs.heads[s_i]
gs.n_kids_in_stack[head] += 1 gs.n_kids_in_stack[head] += 1
stack_words.add(s_i) stack_words.add(s_i)
buffer_words = set() buffer_words = set()
for i in range(stcls.buffer_length()): for i in range(stcls.buffer_length()):
b_i = stcls.B(i) b_i = stcls.B(i)
head = b_i + gs.heads[b_i] head = gs.heads[b_i]
gs.n_kids_in_buffer[head] += 1 gs.n_kids_in_buffer[head] += 1
buffer_words.add(b_i) buffer_words.add(b_i)
for i in range(gs.length): for i in range(gs.length):
head = i + gs.heads[i] head = gs.heads[i]
if head in stack_words: if head in stack_words:
gs.state_bits[i] = set_state_flag( gs.state_bits[i] = set_state_flag(
gs.state_bits[i], gs.state_bits[i],
@ -142,6 +142,58 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
return gs return gs
cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) except *:
    """Refresh the parser-state-dependent fields of `gs` to match `stcls`.

    For every token index below `gs.length` this resets and then recomputes:

    * `gs.n_kids_in_stack` / `gs.n_kids_in_buffer`: the number of words
      currently on the stack / in the buffer whose `gs.heads` entry is
      that token.
    * The HEAD_IN_STACK / HEAD_IN_BUFFER bits of `gs.state_bits`: whether
      the token's `gs.heads` entry is currently on the stack or in the
      buffer. Tokens whose head is in neither keep both bits cleared.

    NOTE(review): `gs.heads[...]` is used directly as an index into the
    per-token count arrays, so it is presumably always a valid token index
    here -- confirm how tokens without an aligned gold head are encoded.
    """
    # Wipe everything that was derived from the previous parser state.
    for tok in range(gs.length):
        gs.state_bits[tok] = set_state_flag(gs.state_bits[tok], HEAD_IN_BUFFER, 0)
        gs.state_bits[tok] = set_state_flag(gs.state_bits[tok], HEAD_IN_STACK, 0)
        gs.n_kids_in_stack[tok] = 0
        gs.n_kids_in_buffer[tok] = 0

    # Credit each stack word to its gold head, and remember which words
    # the stack holds for the membership tests below.
    on_stack = set()
    for depth in range(stcls.stack_depth()):
        stack_word = stcls.S(depth)
        gs.n_kids_in_stack[gs.heads[stack_word]] += 1
        on_stack.add(stack_word)

    # Same bookkeeping for the buffer.
    buffered = set()
    for offset in range(stcls.buffer_length()):
        buffer_word = stcls.B(offset)
        gs.n_kids_in_buffer[gs.heads[buffer_word]] += 1
        buffered.add(buffer_word)

    # Record, per token, where its gold head currently lives. The stack
    # test takes precedence over the buffer test.
    for tok in range(gs.length):
        gold_head = gs.heads[tok]
        if gold_head in on_stack:
            gs.state_bits[tok] = set_state_flag(gs.state_bits[tok], HEAD_IN_STACK, 1)
            gs.state_bits[tok] = set_state_flag(gs.state_bits[tok], HEAD_IN_BUFFER, 0)
        elif gold_head in buffered:
            gs.state_bits[tok] = set_state_flag(gs.state_bits[tok], HEAD_IN_STACK, 0)
            gs.state_bits[tok] = set_state_flag(gs.state_bits[tok], HEAD_IN_BUFFER, 1)
cdef class ArcEagerGold: cdef class ArcEagerGold:
cdef GoldParseStateC c cdef GoldParseStateC c
cdef Pool mem cdef Pool mem
@ -150,6 +202,9 @@ cdef class ArcEagerGold:
self.mem = Pool() self.mem = Pool()
self.c = create_gold_state(self.mem, stcls, example) self.c = create_gold_state(self.mem, stcls, example)
def update(self, StateClass stcls):
    """Recompute this gold state's stack/buffer bookkeeping for `stcls`.

    Delegates to `update_gold_state`, which modifies `self.c` in place.
    """
    update_gold_state(&self.c, stcls)
cdef int check_state_gold(char state_bits, char flag) nogil: cdef int check_state_gold(char state_bits, char flag) nogil:
@ -183,7 +238,7 @@ cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil:
cdef weight_t cost = 0 cdef weight_t cost = 0
if is_head_in_stack(gold, target): if is_head_in_stack(gold, target):
cost += 1 cost += 1
cost += gold.n_kids_in_buffer[target] cost += gold.n_kids_in_stack[target]
if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0: if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
cost += 1 cost += 1
return cost return cost
@ -319,22 +374,27 @@ cdef class LeftArc:
@staticmethod @staticmethod
cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold gold = <const GoldParseStateC*>_gold
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label) return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
@staticmethod @staticmethod
cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil: cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil:
gold = <const GoldParseStateC*>_gold cdef weight_t cost = 0
if arc_is_gold(gold, s.S(0), s.B(0)): s0 = s.S(0)
return 0 b0 = s.B(0)
elif s.c.shifted[s.B(0)]: if arc_is_gold(gold, b0, s0):
return push_cost(s, gold, s.B(0)) # Have a negative cost if we 'recover' from the wrong dependency
return 0 if not s.has_head(s0) else -1
else: else:
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0)) # Account for deps we might lose between S0 and stack
if not s.has_head(s0):
cost += gold.n_kids_in_stack[s0]
if is_head_in_buffer(gold, s0):
cost += 1
return cost + pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
@staticmethod @staticmethod
cdef weight_t label_cost(StateClass s, const void* _gold, attr_t label) nogil: cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
cdef class RightArc: cdef class RightArc:
@ -502,9 +562,6 @@ cdef class ArcEager(TransitionSystem):
def action_types(self): def action_types(self):
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK) return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
def get_cost(self, StateClass state, Example gold, action):
raise NotImplementedError
def transition(self, StateClass state, action): def transition(self, StateClass state, action):
cdef Transition t = self.lookup_transition(action) cdef Transition t = self.lookup_transition(action)
t.do(state.c, t.label) t.do(state.c, t.label)
@ -619,45 +676,32 @@ cdef class ArcEager(TransitionSystem):
output[i] = self.c[i].is_valid(st, self.c[i].label) output[i] = self.c[i].is_valid(st, self.c[i].label)
else: else:
output[i] = is_valid[self.c[i].move] output[i] = is_valid[self.c[i].move]
def get_cost(self, StateClass stcls, gold, int i):
    """Return the cost of taking transition number `i` in state `stcls`.

    stcls: The current parser state.
    gold: The gold-standard state; must be an `ArcEagerGold` instance.
    i: Index of the transition in `self.c`.
    RETURNS: The value of the transition's `get_cost` function if the
        transition is valid in `stcls`, otherwise the sentinel 9000.
    RAISES: TypeError if `gold` is not an `ArcEagerGold`.

    This method reads `gold` as-is and does not call `gold.update(stcls)`
    itself.
    """
    if not isinstance(gold, ArcEagerGold):
        raise TypeError("Expected ArcEagerGold")
    cdef ArcEagerGold gold_ = gold
    gold_state = gold_.c
    if self.c[i].is_valid(stcls.c, self.c[i].label):
        cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
    else:
        # Invalid transitions get the sentinel cost 9000.
        cost = 9000
    return cost
cdef int set_costs(self, int* is_valid, weight_t* costs, cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass stcls, gold) except -1: StateClass stcls, gold) except -1:
gold_state = (<ArcEagerGold>gold).c if not isinstance(gold, ArcEagerGold):
cdef int i, move raise TypeError("Expected ArcEagerGold")
cdef attr_t label cdef ArcEagerGold gold_ = gold
cdef label_cost_func_t[N_MOVES] label_cost_funcs gold_.update(stcls)
cdef move_cost_func_t[N_MOVES] move_cost_funcs gold_state = gold_.c
cdef weight_t[N_MOVES] move_costs
for i in range(N_MOVES):
move_costs[i] = 9000
move_cost_funcs[SHIFT] = Shift.move_cost
move_cost_funcs[REDUCE] = Reduce.move_cost
move_cost_funcs[LEFT] = LeftArc.move_cost
move_cost_funcs[RIGHT] = RightArc.move_cost
move_cost_funcs[BREAK] = Break.move_cost
label_cost_funcs[SHIFT] = Shift.label_cost
label_cost_funcs[REDUCE] = Reduce.label_cost
label_cost_funcs[LEFT] = LeftArc.label_cost
label_cost_funcs[RIGHT] = RightArc.label_cost
label_cost_funcs[BREAK] = Break.label_cost
cdef attr_t* labels = gold_state.labels
cdef int32_t* heads = gold_state.heads
n_gold = 0 n_gold = 0
for i in range(self.n_moves): for i in range(self.n_moves):
if self.c[i].is_valid(stcls.c, self.c[i].label): if self.c[i].is_valid(stcls.c, self.c[i].label):
is_valid[i] = True is_valid[i] = True
move = self.c[i].move costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
label = self.c[i].label n_gold += 1
if move_costs[move] == 9000:
move_costs[move] = move_cost_funcs[move](stcls, &gold_state)
move_cost = move_costs[move]
label_cost = label_cost_funcs[move](stcls, &gold_state, label)
costs[i] = move_cost + label_cost
n_gold += costs[i] <= 0
print(move, label, costs[i])
else: else:
is_valid[i] = False is_valid[i] = False
costs[i] = 9000 costs[i] = 9000

View File

@ -1,4 +1,5 @@
# cython: infer_types=True # cython: infer_types=True
from __future__ import print_function
from cpython.ref cimport Py_INCREF from cpython.ref cimport Py_INCREF
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
@ -67,11 +68,13 @@ cdef class TransitionSystem:
costs = <float*>mem.alloc(self.n_moves, sizeof(float)) costs = <float*>mem.alloc(self.n_moves, sizeof(float))
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int)) is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
cdef StateClass state = StateClass(example.predicted, offset=0) cdef StateClass state
self.initialize_state(state.c) states, golds, n_steps = self.init_gold_batch([example])
state = states[0]
gold = golds[0]
history = [] history = []
while not state.is_final(): while not state.is_final():
self.set_costs(is_valid, costs, state, example) self.set_costs(is_valid, costs, state, gold)
for i in range(self.n_moves): for i in range(self.n_moves):
if is_valid[i] and costs[i] <= 0: if is_valid[i] and costs[i] <= 0:
action = self.c[i] action = self.c[i]