From e9860daf4b3063e5129de3a43945b14b826af3ca Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 21 Jun 2020 23:25:29 +0200 Subject: [PATCH] Update ArcEager oracle Fix Break oracle --- spacy/syntax/arc_eager.pyx | 138 +++++++++++++++++++++++++++++++++---- 1 file changed, 126 insertions(+), 12 deletions(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 3d9071bcb..76f5e7e04 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -53,6 +53,8 @@ cdef enum: HEAD_IN_STACK = 0 HEAD_IN_BUFFER HEAD_UNKNOWN + IS_SENT_START + SENT_START_UNKNOWN cdef struct GoldParseStateC: @@ -76,6 +78,43 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp gs.n_kids_in_stack = mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0])) heads, labels = example.get_aligned_parse(projectivize=True) + sent_starts = example.get_aligned("SENT_START") + for i, is_sent_start in enumerate(sent_starts): + if is_sent_start == True: + gs.state_bits[i] = set_state_flag( + gs.state_bits[i], + IS_SENT_START, + 1 + ) + gs.state_bits[i] = set_state_flag( + gs.state_bits[i], + SENT_START_UNKNOWN, + 0 + ) + + elif is_sent_start is None: + gs.state_bits[i] = set_state_flag( + gs.state_bits[i], + SENT_START_UNKNOWN, + 1 + ) + gs.state_bits[i] = set_state_flag( + gs.state_bits[i], + IS_SENT_START, + 0 + ) + else: + gs.state_bits[i] = set_state_flag( + gs.state_bits[i], + SENT_START_UNKNOWN, + 0 + ) + gs.state_bits[i] = set_state_flag( + gs.state_bits[i], + IS_SENT_START, + 0 + ) + cdef TokenC ref_tok for i, (head, label) in enumerate(zip(heads, labels)): if head is not None: @@ -220,6 +259,13 @@ cdef int is_head_in_buffer(const GoldParseStateC* gold, int i) nogil: cdef int is_head_unknown(const GoldParseStateC* gold, int i) nogil: return check_state_gold(gold.state_bits[i], HEAD_UNKNOWN) +cdef int is_sent_start(const GoldParseStateC* gold, int i) nogil: + return check_state_gold(gold.state_bits[i], IS_SENT_START) + +cdef int is_sent_start_unknown(const GoldParseStateC* gold, int i) nogil: + return check_state_gold(gold.state_bits[i], SENT_START_UNKNOWN) + + # Helper functions for the arc-eager oracle cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil: @@ -251,7 +297,7 @@ cdef weight_t arc_cost(StateClass stcls, const void* _gold, int head, int child) elif stcls.H(child) == gold.heads[child]: return 1 # Head in buffer - elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != 0: + elif is_head_in_buffer(gold, child): return 1 else: return 0 @@ -452,15 +498,15 @@ cdef class Break: @staticmethod cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil: gold = _gold - cdef weight_t cost = 0 - cdef int i, j, S_i, B_i + cost = 0 for i in range(s.stack_depth()): S_i = s.S(i) cost += gold.n_kids_in_buffer[S_i] if is_head_in_buffer(gold, S_i): cost += 1 - # Check for sentence boundary --- if it's here, we can't have any deps - # between stack and buffer, so rest of action is irrelevant. + # It's weird not to check the gold sentence boundaries but if we do, + # we can't account for "sunk costs", i.e. situations where we're already + # wrong. s0_root = _get_root(s.S(0), gold) b0_root = _get_root(s.B(0), gold) if s0_root != b0_root or s0_root == -1 or b0_root == -1: @@ -538,6 +584,7 @@ cdef class ArcEager(TransitionSystem): for label, freq in list(label_freqs.items()): if freq < min_freq: label_freqs.pop(label) + print("Removing", action, label, freq) # Ensure these actions are present actions[BREAK].setdefault('ROOT', 0) if kwargs.get("learn_tokens") is True: @@ -588,7 +635,7 @@ cdef class ArcEager(TransitionSystem): for i in range(self.n_moves): if self.c[i].move == move and self.c[i].label == label: return self.c[i] - return Transition(clas=0, move=MISSING, label=0) + raise KeyError(f"Unknown transition: {name}") def move_name(self, int move, attr_t label): label_str = self.strings[label] @@ -691,12 +738,79 @@ cdef class ArcEager(TransitionSystem): if self.c[i].is_valid(stcls.c, self.c[i].label): is_valid[i] = True costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label) - n_gold += 1 + n_gold += costs[i] <= 0 else: is_valid[i] = False costs[i] = 9000 - if n_gold < 1: - raise ValueError - #failure_state = stcls.print_state([t.text for t in example]) - #raise ValueError( - # Errors.E021.format(n_actions=self.n_moves, state=failure_state)) + #if n_gold < 1: + # raise ValueError + # #failure_state = stcls.print_state([t.text for t in example]) + # #raise ValueError( + # # Errors.E021.format(n_actions=self.n_moves, state=failure_state)) + + def get_oracle_sequence(self, Example example): + cdef Pool mem = Pool() + # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc + assert self.n_moves > 0 + costs = mem.alloc(self.n_moves, sizeof(float)) + is_valid = mem.alloc(self.n_moves, sizeof(int)) + + cdef StateClass state + cdef ArcEagerGold gold + states, golds, n_steps = self.init_gold_batch([example]) + state = states[0] + gold = golds[0] + history = [] + debug_log = [] + while not state.is_final(): + self.set_costs(is_valid, costs, state, gold) + for i in range(self.n_moves): + if is_valid[i] and costs[i] <= 0: + action = self.c[i] + history.append(i) + s0 = state.S(0) + b0 = state.B(0) + debug_log.append(" ".join(( + self.get_class_name(i), + "S0=", (example.x[s0].text if s0 >= 0 else "__"), + "B0=", (example.x[b0].text if b0 >= 0 else "__"), + "S0 head?", str(state.has_head(state.S(0))), + ))) + action.do(state.c, action.label) + break + else: + print("Actions") + for i in range(self.n_moves): + print(self.get_class_name(i)) + print("Gold") + for token in example.y: + print(token.i, token.text, token.dep_, token.head.text) + aligned_heads, aligned_labels = example.get_aligned_parse() + print("Aligned heads") + for i, head in enumerate(aligned_heads): + print(example.x[i], example.x[head] if head is not None else "__") + + print("Predicted tokens") + print([(w.i, w.text) for w in example.x]) + s0 = state.S(0) + b0 = state.B(0) + debug_log.append(" ".join(( + "?", + "S0=", (example.x[s0].text if s0 >= 0 else "-"), + "B0=", (example.x[b0].text if b0 >= 0 else "-"), + "S0 head?", str(state.has_head(state.S(0))), + ))) + s0 = state.S(0) + b0 = state.B(0) + print("\n".join(debug_log)) + print("Arc is gold B0, S0?", arc_is_gold(&gold.c, b0, s0)) + print("Arc is gold S0, B0?", arc_is_gold(&gold.c, s0, b0)) + print("is_head_unknown(s0)", is_head_unknown(&gold.c, s0)) + print("is_head_unknown(b0)", is_head_unknown(&gold.c, b0)) + print("b0", b0, "gold.heads[s0]", gold.c.heads[s0]) + print("Stack", [example.x[i] for i in state.stack]) + print("Buffer", [example.x[i] for i in state.queue]) + raise ValueError(Errors.E024) + return history + +