mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 17:10:36 +03:00
Start debugging arc_eager oracle
This commit is contained in:
parent
2bcb5881d7
commit
0c10831b14
|
@ -76,12 +76,15 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
|
||||||
gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0]))
|
gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0]))
|
||||||
|
|
||||||
cand_to_gold = example.alignment.cand_to_gold
|
cand_to_gold = example.alignment.cand_to_gold
|
||||||
|
gold_to_cand = example.alignment.cand_to_gold
|
||||||
cdef TokenC ref_tok
|
cdef TokenC ref_tok
|
||||||
for cand_i in range(example.x.length):
|
for cand_i in range(example.x.length):
|
||||||
gold_i = cand_to_gold[cand_i]
|
gold_i = cand_to_gold[cand_i]
|
||||||
if cand_i is not None: # Alignment found
|
if gold_i is not None: # Alignment found
|
||||||
ref_tok = example.y.c[gold_i]
|
ref_tok = example.y.c[gold_i]
|
||||||
gs.heads[cand_i] = ref_tok.head
|
gold_head = gold_to_cand[ref_tok.head + gold_i]
|
||||||
|
if gold_head is not None:
|
||||||
|
gs.heads[cand_i] = gold_head
|
||||||
gs.labels[cand_i] = ref_tok.dep
|
gs.labels[cand_i] = ref_tok.dep
|
||||||
gs.state_bits[cand_i] = set_state_flag(
|
gs.state_bits[cand_i] = set_state_flag(
|
||||||
gs.state_bits[cand_i],
|
gs.state_bits[cand_i],
|
||||||
|
@ -94,6 +97,12 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
|
||||||
HEAD_UNKNOWN,
|
HEAD_UNKNOWN,
|
||||||
1
|
1
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
gs.state_bits[cand_i] = set_state_flag(
|
||||||
|
gs.state_bits[cand_i],
|
||||||
|
HEAD_UNKNOWN,
|
||||||
|
1
|
||||||
|
)
|
||||||
stack_words = set()
|
stack_words = set()
|
||||||
for i in range(stcls.stack_depth()):
|
for i in range(stcls.stack_depth()):
|
||||||
s_i = stcls.S(i)
|
s_i = stcls.S(i)
|
||||||
|
@ -135,6 +144,8 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
|
||||||
|
|
||||||
cdef class ArcEagerGold:
|
cdef class ArcEagerGold:
|
||||||
cdef GoldParseStateC c
|
cdef GoldParseStateC c
|
||||||
|
cdef Pool mem
|
||||||
|
|
||||||
def __init__(self, ArcEager moves, StateClass stcls, Example example):
|
def __init__(self, ArcEager moves, StateClass stcls, Example example):
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
self.c = create_gold_state(self.mem, stcls, example)
|
self.c = create_gold_state(self.mem, stcls, example)
|
||||||
|
@ -610,9 +621,8 @@ cdef class ArcEager(TransitionSystem):
|
||||||
output[i] = is_valid[self.c[i].move]
|
output[i] = is_valid[self.c[i].move]
|
||||||
|
|
||||||
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
||||||
StateClass stcls, Example example) except -1:
|
StateClass stcls, gold) except -1:
|
||||||
cdef Pool mem = Pool()
|
gold_state = (<ArcEagerGold>gold).c
|
||||||
gold_state = create_gold_state(mem, stcls, example)
|
|
||||||
cdef int i, move
|
cdef int i, move
|
||||||
cdef attr_t label
|
cdef attr_t label
|
||||||
cdef label_cost_func_t[N_MOVES] label_cost_funcs
|
cdef label_cost_func_t[N_MOVES] label_cost_funcs
|
||||||
|
@ -643,16 +653,16 @@ cdef class ArcEager(TransitionSystem):
|
||||||
label = self.c[i].label
|
label = self.c[i].label
|
||||||
if move_costs[move] == 9000:
|
if move_costs[move] == 9000:
|
||||||
move_costs[move] = move_cost_funcs[move](stcls, &gold_state)
|
move_costs[move] = move_cost_funcs[move](stcls, &gold_state)
|
||||||
costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold_state, label)
|
move_cost = move_costs[move]
|
||||||
|
label_cost = label_cost_funcs[move](stcls, &gold_state, label)
|
||||||
|
costs[i] = move_cost + label_cost
|
||||||
n_gold += costs[i] <= 0
|
n_gold += costs[i] <= 0
|
||||||
|
print(move, label, costs[i])
|
||||||
else:
|
else:
|
||||||
is_valid[i] = False
|
is_valid[i] = False
|
||||||
costs[i] = 9000
|
costs[i] = 9000
|
||||||
if n_gold < 1:
|
if n_gold < 1:
|
||||||
# Check projectivity --- leading cause
|
raise ValueError
|
||||||
if is_nonproj_tree(example.get_field("HEAD")):
|
#failure_state = stcls.print_state([t.text for t in example])
|
||||||
raise ValueError(Errors.E020)
|
#raise ValueError(
|
||||||
else:
|
# Errors.E021.format(n_actions=self.n_moves, state=failure_state))
|
||||||
failure_state = stcls.print_state([t.text for t in example])
|
|
||||||
raise ValueError(Errors.E021.format(n_actions=self.n_moves,
|
|
||||||
state=failure_state))
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user