Start debugging arc_eager oracle

This commit is contained in:
Matthew Honnibal 2020-06-20 21:49:46 +02:00
parent b60eede321
commit 456e27dc8b

View File

@ -76,18 +76,27 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0])) gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0]))
cand_to_gold = example.alignment.cand_to_gold cand_to_gold = example.alignment.cand_to_gold
gold_to_cand = example.alignment.cand_to_gold
cdef TokenC ref_tok cdef TokenC ref_tok
for cand_i in range(example.x.length): for cand_i in range(example.x.length):
gold_i = cand_to_gold[cand_i] gold_i = cand_to_gold[cand_i]
if cand_i is not None: # Alignment found if gold_i is not None: # Alignment found
ref_tok = example.y.c[gold_i] ref_tok = example.y.c[gold_i]
gs.heads[cand_i] = ref_tok.head gold_head = gold_to_cand[ref_tok.head + gold_i]
gs.labels[cand_i] = ref_tok.dep if gold_head is not None:
gs.state_bits[cand_i] = set_state_flag( gs.heads[cand_i] = gold_head
gs.state_bits[cand_i], gs.labels[cand_i] = ref_tok.dep
HEAD_UNKNOWN, gs.state_bits[cand_i] = set_state_flag(
0 gs.state_bits[cand_i],
) HEAD_UNKNOWN,
0
)
else:
gs.state_bits[cand_i] = set_state_flag(
gs.state_bits[cand_i],
HEAD_UNKNOWN,
1
)
else: else:
gs.state_bits[cand_i] = set_state_flag( gs.state_bits[cand_i] = set_state_flag(
gs.state_bits[cand_i], gs.state_bits[cand_i],
@ -135,6 +144,8 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
cdef class ArcEagerGold: cdef class ArcEagerGold:
cdef GoldParseStateC c cdef GoldParseStateC c
cdef Pool mem
def __init__(self, ArcEager moves, StateClass stcls, Example example): def __init__(self, ArcEager moves, StateClass stcls, Example example):
self.mem = Pool() self.mem = Pool()
self.c = create_gold_state(self.mem, stcls, example) self.c = create_gold_state(self.mem, stcls, example)
@ -610,9 +621,8 @@ cdef class ArcEager(TransitionSystem):
output[i] = is_valid[self.c[i].move] output[i] = is_valid[self.c[i].move]
cdef int set_costs(self, int* is_valid, weight_t* costs, cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass stcls, Example example) except -1: StateClass stcls, gold) except -1:
cdef Pool mem = Pool() gold_state = (<ArcEagerGold>gold).c
gold_state = create_gold_state(mem, stcls, example)
cdef int i, move cdef int i, move
cdef attr_t label cdef attr_t label
cdef label_cost_func_t[N_MOVES] label_cost_funcs cdef label_cost_func_t[N_MOVES] label_cost_funcs
@ -643,16 +653,16 @@ cdef class ArcEager(TransitionSystem):
label = self.c[i].label label = self.c[i].label
if move_costs[move] == 9000: if move_costs[move] == 9000:
move_costs[move] = move_cost_funcs[move](stcls, &gold_state) move_costs[move] = move_cost_funcs[move](stcls, &gold_state)
costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold_state, label) move_cost = move_costs[move]
label_cost = label_cost_funcs[move](stcls, &gold_state, label)
costs[i] = move_cost + label_cost
n_gold += costs[i] <= 0 n_gold += costs[i] <= 0
print(move, label, costs[i])
else: else:
is_valid[i] = False is_valid[i] = False
costs[i] = 9000 costs[i] = 9000
if n_gold < 1: if n_gold < 1:
# Check projectivity --- leading cause raise ValueError
if is_nonproj_tree(example.get_field("HEAD")): #failure_state = stcls.print_state([t.text for t in example])
raise ValueError(Errors.E020) #raise ValueError(
else: # Errors.E021.format(n_actions=self.n_moves, state=failure_state))
failure_state = stcls.print_state([t.text for t in example])
raise ValueError(Errors.E021.format(n_actions=self.n_moves,
state=failure_state))