From 5ca4c19ef229c218d290ea45c0b90b6c7e673a94 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 21 Jun 2020 01:01:09 +0200 Subject: [PATCH] Work on parser oracle Update arc_eager oracle Restore ArcEager.get_cost function Update transition system --- spacy/syntax/arc_eager.pyx | 146 +++++++++++++++++++---------- spacy/syntax/transition_system.pyx | 9 +- 2 files changed, 101 insertions(+), 54 deletions(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index b0fedd6c4..c7ecbceea 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -82,7 +82,7 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp gold_i = cand_to_gold[cand_i] if gold_i is not None: # Alignment found ref_tok = example.y.c[gold_i] - gold_head = gold_to_cand[ref_tok.head + gold_i] + gold_head = gold_to_cand[gold_i + ref_tok.head] if gold_head is not None: gs.heads[cand_i] = gold_head gs.labels[cand_i] = ref_tok.dep @@ -106,17 +106,17 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp stack_words = set() for i in range(stcls.stack_depth()): s_i = stcls.S(i) - head = s_i + gs.heads[s_i] + head = gs.heads[s_i] gs.n_kids_in_stack[head] += 1 stack_words.add(s_i) buffer_words = set() for i in range(stcls.buffer_length()): b_i = stcls.B(i) - head = b_i + gs.heads[b_i] + head = gs.heads[b_i] gs.n_kids_in_buffer[head] += 1 buffer_words.add(b_i) for i in range(gs.length): - head = i + gs.heads[i] + head = gs.heads[i] if head in stack_words: gs.state_bits[i] = set_state_flag( gs.state_bits[i], @@ -142,6 +142,58 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp return gs +cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) except *: + for i in range(gs.length): + gs.state_bits[i] = set_state_flag( + gs.state_bits[i], + HEAD_IN_BUFFER, + 0 + ) + gs.state_bits[i] = set_state_flag( + gs.state_bits[i], + HEAD_IN_STACK, + 0 + ) + gs.n_kids_in_stack[i] = 0 + gs.n_kids_in_buffer[i] = 0 + stack_words = set() + for i in range(stcls.stack_depth()): + s_i = stcls.S(i) + head = gs.heads[s_i] + gs.n_kids_in_stack[head] += 1 + stack_words.add(s_i) + buffer_words = set() + for i in range(stcls.buffer_length()): + b_i = stcls.B(i) + head = gs.heads[b_i] + gs.n_kids_in_buffer[head] += 1 + buffer_words.add(b_i) + for i in range(gs.length): + head = gs.heads[i] + if head in stack_words: + gs.state_bits[i] = set_state_flag( + gs.state_bits[i], + HEAD_IN_STACK, + 1 + ) + gs.state_bits[i] = set_state_flag( + gs.state_bits[i], + HEAD_IN_BUFFER, + 0 + ) + elif head in buffer_words: + gs.state_bits[i] = set_state_flag( + gs.state_bits[i], + HEAD_IN_STACK, + 0 + ) + gs.state_bits[i] = set_state_flag( + gs.state_bits[i], + HEAD_IN_BUFFER, + 1 + ) + + cdef class ArcEagerGold: cdef GoldParseStateC c cdef Pool mem @@ -150,6 +202,9 @@ cdef class ArcEagerGold: self.mem = Pool() self.c = create_gold_state(self.mem, stcls, example) + def update(self, StateClass stcls): + update_gold_state(&self.c, stcls) + cdef int check_state_gold(char state_bits, char flag) nogil: @@ -183,7 +238,7 @@ cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil: cdef weight_t cost = 0 if is_head_in_stack(gold, target): cost += 1 - cost += gold.n_kids_in_buffer[target] + cost += gold.n_kids_in_stack[target] if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0: cost += 1 return cost @@ -319,22 +374,27 @@ cdef class LeftArc: @staticmethod cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: gold = _gold - return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label) + return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label) @staticmethod - cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil: - gold = _gold - if arc_is_gold(gold, s.S(0), s.B(0)): - return 0 - elif s.c.shifted[s.B(0)]: - return push_cost(s, gold, s.B(0)) + cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil: + cdef weight_t cost = 0 + s0 = s.S(0) + b0 = s.B(0) + if arc_is_gold(gold, b0, s0): + # Have a negative cost if we 'recover' from the wrong dependency + return 0 if not s.has_head(s0) else -1 else: - return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0)) + # Account for deps we might lose between S0 and stack + if not s.has_head(s0): + cost += gold.n_kids_in_stack[s0] + if is_head_in_buffer(gold, s0): + cost += 1 + return cost + pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0)) @staticmethod - cdef weight_t label_cost(StateClass s, const void* _gold, attr_t label) nogil: - gold = _gold - return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label) + cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil: + return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label) cdef class RightArc: @@ -502,9 +562,6 @@ cdef class ArcEager(TransitionSystem): def action_types(self): return (SHIFT, REDUCE, LEFT, RIGHT, BREAK) - def get_cost(self, StateClass state, Example gold, action): - raise NotImplementedError - def transition(self, StateClass state, action): cdef Transition t = self.lookup_transition(action) t.do(state.c, t.label) @@ -619,45 +676,32 @@ cdef class ArcEager(TransitionSystem): output[i] = self.c[i].is_valid(st, self.c[i].label) else: output[i] = is_valid[self.c[i].move] + + def get_cost(self, StateClass stcls, gold, int i): + if not isinstance(gold, ArcEagerGold): + raise TypeError("Expected ArcEagerGold") + cdef ArcEagerGold gold_ = gold + gold_state = gold_.c + n_gold = 0 + if self.c[i].is_valid(stcls.c, self.c[i].label): + cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label) + else: + cost = 9000 + return cost cdef int set_costs(self, int* is_valid, weight_t* costs, StateClass stcls, gold) except -1: - gold_state = (gold).c - cdef int i, move - cdef attr_t label - cdef label_cost_func_t[N_MOVES] label_cost_funcs - cdef move_cost_func_t[N_MOVES] move_cost_funcs - cdef weight_t[N_MOVES] move_costs - for i in range(N_MOVES): - move_costs[i] = 9000 - move_cost_funcs[SHIFT] = Shift.move_cost - move_cost_funcs[REDUCE] = Reduce.move_cost - move_cost_funcs[LEFT] = LeftArc.move_cost - move_cost_funcs[RIGHT] = RightArc.move_cost - move_cost_funcs[BREAK] = Break.move_cost - - label_cost_funcs[SHIFT] = Shift.label_cost - label_cost_funcs[REDUCE] = Reduce.label_cost - label_cost_funcs[LEFT] = LeftArc.label_cost - label_cost_funcs[RIGHT] = RightArc.label_cost - label_cost_funcs[BREAK] = Break.label_cost - - cdef attr_t* labels = gold_state.labels - cdef int32_t* heads = gold_state.heads - + if not isinstance(gold, ArcEagerGold): + raise TypeError("Expected ArcEagerGold") + cdef ArcEagerGold gold_ = gold + gold_.update(stcls) + gold_state = gold_.c n_gold = 0 for i in range(self.n_moves): if self.c[i].is_valid(stcls.c, self.c[i].label): is_valid[i] = True - move = self.c[i].move - label = self.c[i].label - if move_costs[move] == 9000: - move_costs[move] = move_cost_funcs[move](stcls, &gold_state) - move_cost = move_costs[move] - label_cost = label_cost_funcs[move](stcls, &gold_state, label) - costs[i] = move_cost + label_cost - n_gold += costs[i] <= 0 - print(move, label, costs[i]) + costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label) + n_gold += 1 else: is_valid[i] = False costs[i] = 9000 diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index 319550161..46e438e4c 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -1,4 +1,5 @@ # cython: infer_types=True +from __future__ import print_function from cpython.ref cimport Py_INCREF from cymem.cymem cimport Pool @@ -67,11 +68,13 @@ cdef class TransitionSystem: costs = mem.alloc(self.n_moves, sizeof(float)) is_valid = mem.alloc(self.n_moves, sizeof(int)) - cdef StateClass state = StateClass(example.predicted, offset=0) - self.initialize_state(state.c) + cdef StateClass state + states, golds, n_steps = self.init_gold_batch([example]) + state = states[0] + gold = golds[0] history = [] while not state.is_final(): - self.set_costs(is_valid, costs, state, example) + self.set_costs(is_valid, costs, state, gold) for i in range(self.n_moves): if is_valid[i] and costs[i] <= 0: action = self.c[i]