Update ArcEager oracle

Fix Break oracle
This commit is contained in:
Matthew Honnibal 2020-06-21 23:25:29 +02:00
parent 3354758351
commit e9860daf4b

View File

@ -53,6 +53,8 @@ cdef enum:
HEAD_IN_STACK = 0 HEAD_IN_STACK = 0
HEAD_IN_BUFFER HEAD_IN_BUFFER
HEAD_UNKNOWN HEAD_UNKNOWN
IS_SENT_START
SENT_START_UNKNOWN
cdef struct GoldParseStateC: cdef struct GoldParseStateC:
@ -76,6 +78,43 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0])) gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0]))
heads, labels = example.get_aligned_parse(projectivize=True) heads, labels = example.get_aligned_parse(projectivize=True)
sent_starts = example.get_aligned("SENT_START")
for i, is_sent_start in enumerate(sent_starts):
if is_sent_start == True:
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
IS_SENT_START,
1
)
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
SENT_START_UNKNOWN,
0
)
elif is_sent_start is None:
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
SENT_START_UNKNOWN,
1
)
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
IS_SENT_START,
0
)
else:
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
SENT_START_UNKNOWN,
0
)
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
IS_SENT_START,
0
)
cdef TokenC ref_tok cdef TokenC ref_tok
for i, (head, label) in enumerate(zip(heads, labels)): for i, (head, label) in enumerate(zip(heads, labels)):
if head is not None: if head is not None:
@ -220,6 +259,13 @@ cdef int is_head_in_buffer(const GoldParseStateC* gold, int i) nogil:
cdef int is_head_unknown(const GoldParseStateC* gold, int i) nogil: cdef int is_head_unknown(const GoldParseStateC* gold, int i) nogil:
return check_state_gold(gold.state_bits[i], HEAD_UNKNOWN) return check_state_gold(gold.state_bits[i], HEAD_UNKNOWN)
cdef int is_sent_start(const GoldParseStateC* gold, int i) nogil:
    # Query the IS_SENT_START flag in the gold state bits of token `i`.
    # The result is whatever check_state_gold reports for that flag.
    return check_state_gold(
        gold.state_bits[i],
        IS_SENT_START
    )
cdef int is_sent_start_unknown(const GoldParseStateC* gold, int i) nogil:
    # Query the SENT_START_UNKNOWN flag in the gold state bits of token `i`.
    # The result is whatever check_state_gold reports for that flag.
    return check_state_gold(
        gold.state_bits[i],
        SENT_START_UNKNOWN
    )
# Helper functions for the arc-eager oracle # Helper functions for the arc-eager oracle
cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil: cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil:
@ -251,7 +297,7 @@ cdef weight_t arc_cost(StateClass stcls, const void* _gold, int head, int child)
elif stcls.H(child) == gold.heads[child]: elif stcls.H(child) == gold.heads[child]:
return 1 return 1
# Head in buffer # Head in buffer
elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != 0: elif is_head_in_buffer(gold, child):
return 1 return 1
else: else:
return 0 return 0
@ -452,15 +498,15 @@ cdef class Break:
@staticmethod @staticmethod
cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil: cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
gold = <const GoldParseStateC*>_gold gold = <const GoldParseStateC*>_gold
cdef weight_t cost = 0 cost = 0
cdef int i, j, S_i, B_i
for i in range(s.stack_depth()): for i in range(s.stack_depth()):
S_i = s.S(i) S_i = s.S(i)
cost += gold.n_kids_in_buffer[S_i] cost += gold.n_kids_in_buffer[S_i]
if is_head_in_buffer(gold, S_i): if is_head_in_buffer(gold, S_i):
cost += 1 cost += 1
# Check for sentence boundary --- if it's here, we can't have any deps # It's weird not to check the gold sentence boundaries but if we do,
# between stack and buffer, so rest of action is irrelevant. # we can't account for "sunk costs", i.e. situations where we're already
# wrong.
s0_root = _get_root(s.S(0), gold) s0_root = _get_root(s.S(0), gold)
b0_root = _get_root(s.B(0), gold) b0_root = _get_root(s.B(0), gold)
if s0_root != b0_root or s0_root == -1 or b0_root == -1: if s0_root != b0_root or s0_root == -1 or b0_root == -1:
@ -538,6 +584,7 @@ cdef class ArcEager(TransitionSystem):
for label, freq in list(label_freqs.items()): for label, freq in list(label_freqs.items()):
if freq < min_freq: if freq < min_freq:
label_freqs.pop(label) label_freqs.pop(label)
print("Removing", action, label, freq)
# Ensure these actions are present # Ensure these actions are present
actions[BREAK].setdefault('ROOT', 0) actions[BREAK].setdefault('ROOT', 0)
if kwargs.get("learn_tokens") is True: if kwargs.get("learn_tokens") is True:
@ -588,7 +635,7 @@ cdef class ArcEager(TransitionSystem):
for i in range(self.n_moves): for i in range(self.n_moves):
if self.c[i].move == move and self.c[i].label == label: if self.c[i].move == move and self.c[i].label == label:
return self.c[i] return self.c[i]
return Transition(clas=0, move=MISSING, label=0) raise KeyError(f"Unknown transition: {name}")
def move_name(self, int move, attr_t label): def move_name(self, int move, attr_t label):
label_str = self.strings[label] label_str = self.strings[label]
@ -691,12 +738,79 @@ cdef class ArcEager(TransitionSystem):
if self.c[i].is_valid(stcls.c, self.c[i].label): if self.c[i].is_valid(stcls.c, self.c[i].label):
is_valid[i] = True is_valid[i] = True
costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label) costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
n_gold += 1 n_gold += costs[i] <= 0
else: else:
is_valid[i] = False is_valid[i] = False
costs[i] = 9000 costs[i] = 9000
if n_gold < 1: #if n_gold < 1:
raise ValueError # raise ValueError
#failure_state = stcls.print_state([t.text for t in example]) # #failure_state = stcls.print_state([t.text for t in example])
#raise ValueError( # #raise ValueError(
# Errors.E021.format(n_actions=self.n_moves, state=failure_state)) # # Errors.E021.format(n_actions=self.n_moves, state=failure_state))
def get_oracle_sequence(self, Example example):
    """Follow zero-cost transitions for `example` until the state is final.

    At every step the costs of all moves are recomputed with `set_costs`
    and the first move that is both valid and has cost <= 0 is applied.

    example (Example): The example to derive the sequence for.
    RETURNS (list): Indices into the move table, in the order applied.
    RAISES (ValueError): With Errors.E024 if some non-final state has no
        valid move with cost <= 0. Diagnostics are printed to stdout
        before raising.
    """
    cdef Pool mem = Pool()
    # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
    assert self.n_moves > 0
    costs = <float*>mem.alloc(self.n_moves, sizeof(float))
    is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))

    cdef StateClass state
    cdef ArcEagerGold gold
    states, golds, n_steps = self.init_gold_batch([example])
    state = states[0]
    gold = golds[0]
    history = []
    # One human-readable line per applied action; only printed on failure.
    debug_log = []
    while not state.is_final():
        self.set_costs(is_valid, costs, state, gold)
        for i in range(self.n_moves):
            if is_valid[i] and costs[i] <= 0:
                action = self.c[i]
                history.append(i)
                s0 = state.S(0)
                b0 = state.B(0)
                debug_log.append(" ".join((
                    self.get_class_name(i),
                    "S0=", (example.x[s0].text if s0 >= 0 else "__"),
                    "B0=", (example.x[b0].text if b0 >= 0 else "__"),
                    "S0 head?", str(state.has_head(state.S(0))),
                )))
                action.do(state.c, action.label)
                break
        else:
            # The for-loop finished without `break`: no valid zero-cost
            # action exists in this state. Dump diagnostics, then fail.
            print("Actions")
            for move_id in range(self.n_moves):
                print(self.get_class_name(move_id))
            print("Gold")
            for token in example.y:
                print(token.i, token.text, token.dep_, token.head.text)
            aligned_heads, aligned_labels = example.get_aligned_parse()
            print("Aligned heads")
            for tok_i, head in enumerate(aligned_heads):
                print(example.x[tok_i], example.x[head] if head is not None else "__")

            print("Predicted tokens")
            print([(w.i, w.text) for w in example.x])
            s0 = state.S(0)
            b0 = state.B(0)
            debug_log.append(" ".join((
                "?",
                "S0=", (example.x[s0].text if s0 >= 0 else "-"),
                "B0=", (example.x[b0].text if b0 >= 0 else "-"),
                "S0 head?", str(state.has_head(state.S(0))),
            )))
            print("\n".join(debug_log))
            # s0 / b0 can be negative (the text lookups above guard for
            # that). The gold helpers and gold.c.heads index C arrays with
            # these values, so never pass a negative index to them.
            if s0 >= 0 and b0 >= 0:
                print("Arc is gold B0, S0?", arc_is_gold(&gold.c, b0, s0))
                print("Arc is gold S0, B0?", arc_is_gold(&gold.c, s0, b0))
            else:
                print("Arc is gold B0, S0?", "__")
                print("Arc is gold S0, B0?", "__")
            if s0 >= 0:
                print("is_head_unknown(s0)", is_head_unknown(&gold.c, s0))
            else:
                print("is_head_unknown(s0)", "__")
            if b0 >= 0:
                print("is_head_unknown(b0)", is_head_unknown(&gold.c, b0))
            else:
                print("is_head_unknown(b0)", "__")
            if s0 >= 0:
                print("b0", b0, "gold.heads[s0]", gold.c.heads[s0])
            else:
                print("b0", b0, "gold.heads[s0]", "__")
            print("Stack", [example.x[idx] for idx in state.stack])
            print("Buffer", [example.x[idx] for idx in state.queue])
            raise ValueError(Errors.E024)
    return history