From 420a986d15ff56e26c38d377754bbb1eaad08c5e Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 23 Jun 2020 22:58:12 +0200 Subject: [PATCH] Fix arc_eager oracle --- spacy/syntax/arc_eager.pyx | 94 +++++++++++++++++++++----------------- 1 file changed, 51 insertions(+), 43 deletions(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index a6bc10f0c..8a638f31d 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -200,7 +200,6 @@ cdef class ArcEagerGold: sent_starts = example.get_aligned("SENT_START") assert len(heads) == len(labels) == len(sent_starts) self.c = create_gold_state(self.mem, stcls, heads, labels, sent_starts) - self.update(stcls) def update(self, StateClass stcls): update_gold_state(&self.c, stcls) @@ -577,17 +576,12 @@ cdef class ArcEager(TransitionSystem): def is_gold_parse(self, StateClass state, gold): raise NotImplementedError - def has_gold(self, gold, start=0, end=None): - raise NotImplementedError - - def preprocess_gold(self, example): - raise NotImplementedError - def init_gold_batch(self, examples): + examples = [eg for eg in examples if self.has_gold(eg)] states = self.init_batch([eg.predicted for eg in examples]) keeps = [i for i, s in enumerate(states) if not s.is_final()] - states = [states[i] for i in keeps] golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps] + states = [states[i] for i in keeps] for gold in golds: self._replace_unseen_labels(gold) n_steps = sum([len(s.queue) * 4 for s in states]) @@ -690,6 +684,9 @@ cdef class ArcEager(TransitionSystem): doc.is_parsed = True set_children_from_heads(doc.c, doc.length) + def has_gold(self, Example eg): + return eg.y.is_parsed + cdef int set_valid(self, int* output, const StateC* st) nogil: cdef bint[N_MOVES] is_valid is_valid[SHIFT] = Shift.is_valid(st, 0) @@ -736,21 +733,29 @@ cdef class ArcEager(TransitionSystem): raise ValueError def get_oracle_sequence(self, Example example): + cdef StateClass state + cdef ArcEagerGold gold + states, golds, n_steps = self.init_gold_batch([example]) + if not golds: + return [] + cdef Pool mem = Pool() # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc assert self.n_moves > 0 costs = mem.alloc(self.n_moves, sizeof(float)) is_valid = mem.alloc(self.n_moves, sizeof(int)) - cdef StateClass state - cdef ArcEagerGold gold - states, golds, n_steps = self.init_gold_batch([example]) state = states[0] gold = golds[0] history = [] debug_log = [] + failed = False while not state.is_final(): - self.set_costs(is_valid, costs, state, gold) + try: + self.set_costs(is_valid, costs, state, gold) + except ValueError: + failed = True + break for i in range(self.n_moves): if is_valid[i] and costs[i] <= 0: action = self.c[i] @@ -766,36 +771,39 @@ cdef class ArcEager(TransitionSystem): action.do(state.c, action.label) break else: - print("Actions") - for i in range(self.n_moves): - print(self.get_class_name(i)) - print("Gold") - for token in example.y: - print(token.i, token.text, token.dep_, token.head.text) - aligned_heads, aligned_labels = example.get_aligned_parse() - print("Aligned heads") - for i, head in enumerate(aligned_heads): - print(example.x[i], example.x[head] if head is not None else "__") + failed = False + break + if failed: + print("Actions") + for i in range(self.n_moves): + print(self.get_class_name(i)) + print("Gold") + for token in example.y: + print(token.i, token.text, token.dep_, token.head.text) + aligned_heads, aligned_labels = example.get_aligned_parse() + print("Aligned heads") + for i, head in enumerate(aligned_heads): + print(example.x[i], example.x[head] if head is not None else "__") - print("Predicted tokens") - print([(w.i, w.text) for w in example.x]) - s0 = state.S(0) - b0 = state.B(0) - debug_log.append(" ".join(( - "?", - "S0=", (example.x[s0].text if s0 >= 0 else "-"), - "B0=", (example.x[b0].text if b0 >= 0 else "-"), - "S0 head?", str(state.has_head(state.S(0))), - ))) - s0 = state.S(0) - b0 = state.B(0) - print("\n".join(debug_log)) - print("Arc is gold B0, S0?", arc_is_gold(&gold.c, b0, s0)) - print("Arc is gold S0, B0?", arc_is_gold(&gold.c, s0, b0)) - print("is_head_unknown(s0)", is_head_unknown(&gold.c, s0)) - print("is_head_unknown(b0)", is_head_unknown(&gold.c, b0)) - print("b0", b0, "gold.heads[s0]", gold.c.heads[s0]) - print("Stack", [example.x[i] for i in state.stack]) - print("Buffer", [example.x[i] for i in state.queue]) - raise ValueError(Errors.E024) + print("Predicted tokens") + print([(w.i, w.text) for w in example.x]) + s0 = state.S(0) + b0 = state.B(0) + debug_log.append(" ".join(( + "?", + "S0=", (example.x[s0].text if s0 >= 0 else "-"), + "B0=", (example.x[b0].text if b0 >= 0 else "-"), + "S0 head?", str(state.has_head(state.S(0))), + ))) + s0 = state.S(0) + b0 = state.B(0) + print("\n".join(debug_log)) + print("Arc is gold B0, S0?", arc_is_gold(&gold.c, b0, s0)) + print("Arc is gold S0, B0?", arc_is_gold(&gold.c, s0, b0)) + print("is_head_unknown(s0)", is_head_unknown(&gold.c, s0)) + print("is_head_unknown(b0)", is_head_unknown(&gold.c, b0)) + print("b0", b0, "gold.heads[s0]", gold.c.heads[s0]) + print("Stack", [example.x[i] for i in state.stack]) + print("Buffer", [example.x[i] for i in state.queue]) + raise ValueError(Errors.E024) return history