diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx
index 0dfcbf885..b0fedd6c4 100644
--- a/spacy/syntax/arc_eager.pyx
+++ b/spacy/syntax/arc_eager.pyx
@@ -76,18 +76,29 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
     gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0]))
 
     cand_to_gold = example.alignment.cand_to_gold
+    gold_to_cand = example.alignment.gold_to_cand
     cdef TokenC ref_tok
     for cand_i in range(example.x.length):
         gold_i = cand_to_gold[cand_i]
-        if cand_i is not None: # Alignment found
+        if gold_i is not None: # Alignment found
             ref_tok = example.y.c[gold_i]
-            gs.heads[cand_i] = ref_tok.head
-            gs.labels[cand_i] = ref_tok.dep
-            gs.state_bits[cand_i] = set_state_flag(
-                gs.state_bits[cand_i],
-                HEAD_UNKNOWN,
-                0
-            )
+            # ref_tok.head is added to gold_i to index the gold-side table;
+            # the result is translated back into a candidate-side index.
+            gold_head = gold_to_cand[ref_tok.head + gold_i]
+            if gold_head is not None:
+                gs.heads[cand_i] = gold_head
+                gs.labels[cand_i] = ref_tok.dep
+                gs.state_bits[cand_i] = set_state_flag(
+                    gs.state_bits[cand_i],
+                    HEAD_UNKNOWN,
+                    0
+                )
+            else:
+                # The gold head has no aligned candidate token: mark the head unknown.
+                gs.state_bits[cand_i] = set_state_flag(
+                    gs.state_bits[cand_i],
+                    HEAD_UNKNOWN,
+                    1
+                )
         else:
             gs.state_bits[cand_i] = set_state_flag(
                 gs.state_bits[cand_i],
@@ -135,6 +146,8 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
 
 cdef class ArcEagerGold:
     cdef GoldParseStateC c
+    cdef Pool mem
+
     def __init__(self, ArcEager moves, StateClass stcls, Example example):
         self.mem = Pool()
         self.c = create_gold_state(self.mem, stcls, example)
@@ -610,9 +623,8 @@ cdef class ArcEager(TransitionSystem):
             output[i] = is_valid[self.c[i].move]
 
     cdef int set_costs(self, int* is_valid, weight_t* costs,
-            StateClass stcls, Example example) except -1:
-        cdef Pool mem = Pool()
-        gold_state = create_gold_state(mem, stcls, example)
+            StateClass stcls, gold) except -1:
+        gold_state = (<ArcEagerGold>gold).c
         cdef int i, move
         cdef attr_t label
         cdef label_cost_func_t[N_MOVES] label_cost_funcs
@@ -643,16 +655,18 @@ cdef class ArcEager(TransitionSystem):
                 label = self.c[i].label
                 if move_costs[move] == 9000:
                     move_costs[move] = move_cost_funcs[move](stcls, &gold_state)
-                costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold_state, label)
+                move_cost = move_costs[move]
+                label_cost = label_cost_funcs[move](stcls, &gold_state, label)
+                costs[i] = move_cost + label_cost
                 n_gold += costs[i] <= 0
             else:
                 is_valid[i] = False
                 costs[i] = 9000
         if n_gold < 1:
-            # Check projectivity --- leading cause
-            if is_nonproj_tree(example.get_field("HEAD")):
-                raise ValueError(Errors.E020)
-            else:
-                failure_state = stcls.print_state([t.text for t in example])
-                raise ValueError(Errors.E021.format(n_actions=self.n_moves,
-                                                    state=failure_state))
+            # The Example is no longer passed in, so neither the projectivity
+            # check nor the token-text state dump can be produced here.
+            raise ValueError(
+                "Could not find a gold-standard action to supervise the "
+                "dependency parser."
+            )