Fix arc_eager oracle

This commit is contained in:
Matthew Honnibal 2020-06-23 22:58:12 +02:00
parent a68d0e63f0
commit 420a986d15

View File

@ -200,7 +200,6 @@ cdef class ArcEagerGold:
sent_starts = example.get_aligned("SENT_START") sent_starts = example.get_aligned("SENT_START")
assert len(heads) == len(labels) == len(sent_starts) assert len(heads) == len(labels) == len(sent_starts)
self.c = create_gold_state(self.mem, stcls, heads, labels, sent_starts) self.c = create_gold_state(self.mem, stcls, heads, labels, sent_starts)
self.update(stcls)
def update(self, StateClass stcls): def update(self, StateClass stcls):
update_gold_state(&self.c, stcls) update_gold_state(&self.c, stcls)
@ -577,17 +576,12 @@ cdef class ArcEager(TransitionSystem):
def is_gold_parse(self, StateClass state, gold): def is_gold_parse(self, StateClass state, gold):
raise NotImplementedError raise NotImplementedError
def has_gold(self, gold, start=0, end=None):
raise NotImplementedError
def preprocess_gold(self, example):
raise NotImplementedError
def init_gold_batch(self, examples): def init_gold_batch(self, examples):
examples = [eg for eg in examples if self.has_gold(eg)]
states = self.init_batch([eg.predicted for eg in examples]) states = self.init_batch([eg.predicted for eg in examples])
keeps = [i for i, s in enumerate(states) if not s.is_final()] keeps = [i for i, s in enumerate(states) if not s.is_final()]
states = [states[i] for i in keeps]
golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps] golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps]
states = [states[i] for i in keeps]
for gold in golds: for gold in golds:
self._replace_unseen_labels(gold) self._replace_unseen_labels(gold)
n_steps = sum([len(s.queue) * 4 for s in states]) n_steps = sum([len(s.queue) * 4 for s in states])
@ -690,6 +684,9 @@ cdef class ArcEager(TransitionSystem):
doc.is_parsed = True doc.is_parsed = True
set_children_from_heads(doc.c, doc.length) set_children_from_heads(doc.c, doc.length)
def has_gold(self, Example eg):
return eg.y.is_parsed
cdef int set_valid(self, int* output, const StateC* st) nogil: cdef int set_valid(self, int* output, const StateC* st) nogil:
cdef bint[N_MOVES] is_valid cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = Shift.is_valid(st, 0) is_valid[SHIFT] = Shift.is_valid(st, 0)
@ -736,21 +733,29 @@ cdef class ArcEager(TransitionSystem):
raise ValueError raise ValueError
def get_oracle_sequence(self, Example example): def get_oracle_sequence(self, Example example):
cdef StateClass state
cdef ArcEagerGold gold
states, golds, n_steps = self.init_gold_batch([example])
if not golds:
return []
cdef Pool mem = Pool() cdef Pool mem = Pool()
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
assert self.n_moves > 0 assert self.n_moves > 0
costs = <float*>mem.alloc(self.n_moves, sizeof(float)) costs = <float*>mem.alloc(self.n_moves, sizeof(float))
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int)) is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
cdef StateClass state
cdef ArcEagerGold gold
states, golds, n_steps = self.init_gold_batch([example])
state = states[0] state = states[0]
gold = golds[0] gold = golds[0]
history = [] history = []
debug_log = [] debug_log = []
failed = False
while not state.is_final(): while not state.is_final():
self.set_costs(is_valid, costs, state, gold) try:
self.set_costs(is_valid, costs, state, gold)
except ValueError:
failed = True
break
for i in range(self.n_moves): for i in range(self.n_moves):
if is_valid[i] and costs[i] <= 0: if is_valid[i] and costs[i] <= 0:
action = self.c[i] action = self.c[i]
@ -766,36 +771,39 @@ cdef class ArcEager(TransitionSystem):
action.do(state.c, action.label) action.do(state.c, action.label)
break break
else: else:
print("Actions") failed = False
for i in range(self.n_moves): break
print(self.get_class_name(i)) if failed:
print("Gold") print("Actions")
for token in example.y: for i in range(self.n_moves):
print(token.i, token.text, token.dep_, token.head.text) print(self.get_class_name(i))
aligned_heads, aligned_labels = example.get_aligned_parse() print("Gold")
print("Aligned heads") for token in example.y:
for i, head in enumerate(aligned_heads): print(token.i, token.text, token.dep_, token.head.text)
print(example.x[i], example.x[head] if head is not None else "__") aligned_heads, aligned_labels = example.get_aligned_parse()
print("Aligned heads")
for i, head in enumerate(aligned_heads):
print(example.x[i], example.x[head] if head is not None else "__")
print("Predicted tokens") print("Predicted tokens")
print([(w.i, w.text) for w in example.x]) print([(w.i, w.text) for w in example.x])
s0 = state.S(0) s0 = state.S(0)
b0 = state.B(0) b0 = state.B(0)
debug_log.append(" ".join(( debug_log.append(" ".join((
"?", "?",
"S0=", (example.x[s0].text if s0 >= 0 else "-"), "S0=", (example.x[s0].text if s0 >= 0 else "-"),
"B0=", (example.x[b0].text if b0 >= 0 else "-"), "B0=", (example.x[b0].text if b0 >= 0 else "-"),
"S0 head?", str(state.has_head(state.S(0))), "S0 head?", str(state.has_head(state.S(0))),
))) )))
s0 = state.S(0) s0 = state.S(0)
b0 = state.B(0) b0 = state.B(0)
print("\n".join(debug_log)) print("\n".join(debug_log))
print("Arc is gold B0, S0?", arc_is_gold(&gold.c, b0, s0)) print("Arc is gold B0, S0?", arc_is_gold(&gold.c, b0, s0))
print("Arc is gold S0, B0?", arc_is_gold(&gold.c, s0, b0)) print("Arc is gold S0, B0?", arc_is_gold(&gold.c, s0, b0))
print("is_head_unknown(s0)", is_head_unknown(&gold.c, s0)) print("is_head_unknown(s0)", is_head_unknown(&gold.c, s0))
print("is_head_unknown(b0)", is_head_unknown(&gold.c, b0)) print("is_head_unknown(b0)", is_head_unknown(&gold.c, b0))
print("b0", b0, "gold.heads[s0]", gold.c.heads[s0]) print("b0", b0, "gold.heads[s0]", gold.c.heads[s0])
print("Stack", [example.x[i] for i in state.stack]) print("Stack", [example.x[i] for i in state.stack])
print("Buffer", [example.x[i] for i in state.queue]) print("Buffer", [example.x[i] for i in state.queue])
raise ValueError(Errors.E024) raise ValueError(Errors.E024)
return history return history