mirror of
https://github.com/explosion/spaCy.git
synced 2025-03-12 07:15:48 +03:00
Update ArcEager oracle
Fix Break oracle
This commit is contained in:
parent
3354758351
commit
e9860daf4b
|
@ -53,6 +53,8 @@ cdef enum:
|
||||||
HEAD_IN_STACK = 0
|
HEAD_IN_STACK = 0
|
||||||
HEAD_IN_BUFFER
|
HEAD_IN_BUFFER
|
||||||
HEAD_UNKNOWN
|
HEAD_UNKNOWN
|
||||||
|
IS_SENT_START
|
||||||
|
SENT_START_UNKNOWN
|
||||||
|
|
||||||
|
|
||||||
cdef struct GoldParseStateC:
|
cdef struct GoldParseStateC:
|
||||||
|
@ -76,6 +78,43 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
|
||||||
gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0]))
|
gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0]))
|
||||||
|
|
||||||
heads, labels = example.get_aligned_parse(projectivize=True)
|
heads, labels = example.get_aligned_parse(projectivize=True)
|
||||||
|
sent_starts = example.get_aligned("SENT_START")
|
||||||
|
for i, is_sent_start in enumerate(sent_starts):
|
||||||
|
if is_sent_start == True:
|
||||||
|
gs.state_bits[i] = set_state_flag(
|
||||||
|
gs.state_bits[i],
|
||||||
|
IS_SENT_START,
|
||||||
|
1
|
||||||
|
)
|
||||||
|
gs.state_bits[i] = set_state_flag(
|
||||||
|
gs.state_bits[i],
|
||||||
|
SENT_START_UNKNOWN,
|
||||||
|
0
|
||||||
|
)
|
||||||
|
|
||||||
|
elif is_sent_start is None:
|
||||||
|
gs.state_bits[i] = set_state_flag(
|
||||||
|
gs.state_bits[i],
|
||||||
|
SENT_START_UNKNOWN,
|
||||||
|
1
|
||||||
|
)
|
||||||
|
gs.state_bits[i] = set_state_flag(
|
||||||
|
gs.state_bits[i],
|
||||||
|
IS_SENT_START,
|
||||||
|
0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
gs.state_bits[i] = set_state_flag(
|
||||||
|
gs.state_bits[i],
|
||||||
|
SENT_START_UNKNOWN,
|
||||||
|
0
|
||||||
|
)
|
||||||
|
gs.state_bits[i] = set_state_flag(
|
||||||
|
gs.state_bits[i],
|
||||||
|
IS_SENT_START,
|
||||||
|
0
|
||||||
|
)
|
||||||
|
|
||||||
cdef TokenC ref_tok
|
cdef TokenC ref_tok
|
||||||
for i, (head, label) in enumerate(zip(heads, labels)):
|
for i, (head, label) in enumerate(zip(heads, labels)):
|
||||||
if head is not None:
|
if head is not None:
|
||||||
|
@ -220,6 +259,13 @@ cdef int is_head_in_buffer(const GoldParseStateC* gold, int i) nogil:
|
||||||
cdef int is_head_unknown(const GoldParseStateC* gold, int i) nogil:
|
cdef int is_head_unknown(const GoldParseStateC* gold, int i) nogil:
|
||||||
return check_state_gold(gold.state_bits[i], HEAD_UNKNOWN)
|
return check_state_gold(gold.state_bits[i], HEAD_UNKNOWN)
|
||||||
|
|
||||||
|
cdef int is_sent_start(const GoldParseStateC* gold, int i) nogil:
|
||||||
|
return check_state_gold(gold.state_bits[i], IS_SENT_START)
|
||||||
|
|
||||||
|
cdef int is_sent_start_unknown(const GoldParseStateC* gold, int i) nogil:
|
||||||
|
return check_state_gold(gold.state_bits[i], SENT_START_UNKNOWN)
|
||||||
|
|
||||||
|
|
||||||
# Helper functions for the arc-eager oracle
|
# Helper functions for the arc-eager oracle
|
||||||
|
|
||||||
cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil:
|
cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil:
|
||||||
|
@ -251,7 +297,7 @@ cdef weight_t arc_cost(StateClass stcls, const void* _gold, int head, int child)
|
||||||
elif stcls.H(child) == gold.heads[child]:
|
elif stcls.H(child) == gold.heads[child]:
|
||||||
return 1
|
return 1
|
||||||
# Head in buffer
|
# Head in buffer
|
||||||
elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != 0:
|
elif is_head_in_buffer(gold, child):
|
||||||
return 1
|
return 1
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
@ -452,15 +498,15 @@ cdef class Break:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
|
cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
|
||||||
gold = <const GoldParseStateC*>_gold
|
gold = <const GoldParseStateC*>_gold
|
||||||
cdef weight_t cost = 0
|
cost = 0
|
||||||
cdef int i, j, S_i, B_i
|
|
||||||
for i in range(s.stack_depth()):
|
for i in range(s.stack_depth()):
|
||||||
S_i = s.S(i)
|
S_i = s.S(i)
|
||||||
cost += gold.n_kids_in_buffer[S_i]
|
cost += gold.n_kids_in_buffer[S_i]
|
||||||
if is_head_in_buffer(gold, S_i):
|
if is_head_in_buffer(gold, S_i):
|
||||||
cost += 1
|
cost += 1
|
||||||
# Check for sentence boundary --- if it's here, we can't have any deps
|
# It's weird not to check the gold sentence boundaries but if we do,
|
||||||
# between stack and buffer, so rest of action is irrelevant.
|
# we can't account for "sunk costs", i.e. situations where we're already
|
||||||
|
# wrong.
|
||||||
s0_root = _get_root(s.S(0), gold)
|
s0_root = _get_root(s.S(0), gold)
|
||||||
b0_root = _get_root(s.B(0), gold)
|
b0_root = _get_root(s.B(0), gold)
|
||||||
if s0_root != b0_root or s0_root == -1 or b0_root == -1:
|
if s0_root != b0_root or s0_root == -1 or b0_root == -1:
|
||||||
|
@ -538,6 +584,7 @@ cdef class ArcEager(TransitionSystem):
|
||||||
for label, freq in list(label_freqs.items()):
|
for label, freq in list(label_freqs.items()):
|
||||||
if freq < min_freq:
|
if freq < min_freq:
|
||||||
label_freqs.pop(label)
|
label_freqs.pop(label)
|
||||||
|
print("Removing", action, label, freq)
|
||||||
# Ensure these actions are present
|
# Ensure these actions are present
|
||||||
actions[BREAK].setdefault('ROOT', 0)
|
actions[BREAK].setdefault('ROOT', 0)
|
||||||
if kwargs.get("learn_tokens") is True:
|
if kwargs.get("learn_tokens") is True:
|
||||||
|
@ -588,7 +635,7 @@ cdef class ArcEager(TransitionSystem):
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
if self.c[i].move == move and self.c[i].label == label:
|
if self.c[i].move == move and self.c[i].label == label:
|
||||||
return self.c[i]
|
return self.c[i]
|
||||||
return Transition(clas=0, move=MISSING, label=0)
|
raise KeyError(f"Unknown transition: {name}")
|
||||||
|
|
||||||
def move_name(self, int move, attr_t label):
|
def move_name(self, int move, attr_t label):
|
||||||
label_str = self.strings[label]
|
label_str = self.strings[label]
|
||||||
|
@ -691,12 +738,79 @@ cdef class ArcEager(TransitionSystem):
|
||||||
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
||||||
is_valid[i] = True
|
is_valid[i] = True
|
||||||
costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
|
costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
|
||||||
n_gold += 1
|
n_gold += costs[i] <= 0
|
||||||
else:
|
else:
|
||||||
is_valid[i] = False
|
is_valid[i] = False
|
||||||
costs[i] = 9000
|
costs[i] = 9000
|
||||||
if n_gold < 1:
|
#if n_gold < 1:
|
||||||
raise ValueError
|
# raise ValueError
|
||||||
#failure_state = stcls.print_state([t.text for t in example])
|
# #failure_state = stcls.print_state([t.text for t in example])
|
||||||
#raise ValueError(
|
# #raise ValueError(
|
||||||
# Errors.E021.format(n_actions=self.n_moves, state=failure_state))
|
# # Errors.E021.format(n_actions=self.n_moves, state=failure_state))
|
||||||
|
|
||||||
|
def get_oracle_sequence(self, Example example):
|
||||||
|
cdef Pool mem = Pool()
|
||||||
|
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
||||||
|
assert self.n_moves > 0
|
||||||
|
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
|
||||||
|
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
|
||||||
|
|
||||||
|
cdef StateClass state
|
||||||
|
cdef ArcEagerGold gold
|
||||||
|
states, golds, n_steps = self.init_gold_batch([example])
|
||||||
|
state = states[0]
|
||||||
|
gold = golds[0]
|
||||||
|
history = []
|
||||||
|
debug_log = []
|
||||||
|
while not state.is_final():
|
||||||
|
self.set_costs(is_valid, costs, state, gold)
|
||||||
|
for i in range(self.n_moves):
|
||||||
|
if is_valid[i] and costs[i] <= 0:
|
||||||
|
action = self.c[i]
|
||||||
|
history.append(i)
|
||||||
|
s0 = state.S(0)
|
||||||
|
b0 = state.B(0)
|
||||||
|
debug_log.append(" ".join((
|
||||||
|
self.get_class_name(i),
|
||||||
|
"S0=", (example.x[s0].text if s0 >= 0 else "__"),
|
||||||
|
"B0=", (example.x[b0].text if b0 >= 0 else "__"),
|
||||||
|
"S0 head?", str(state.has_head(state.S(0))),
|
||||||
|
)))
|
||||||
|
action.do(state.c, action.label)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print("Actions")
|
||||||
|
for i in range(self.n_moves):
|
||||||
|
print(self.get_class_name(i))
|
||||||
|
print("Gold")
|
||||||
|
for token in example.y:
|
||||||
|
print(token.i, token.text, token.dep_, token.head.text)
|
||||||
|
aligned_heads, aligned_labels = example.get_aligned_parse()
|
||||||
|
print("Aligned heads")
|
||||||
|
for i, head in enumerate(aligned_heads):
|
||||||
|
print(example.x[i], example.x[head] if head is not None else "__")
|
||||||
|
|
||||||
|
print("Predicted tokens")
|
||||||
|
print([(w.i, w.text) for w in example.x])
|
||||||
|
s0 = state.S(0)
|
||||||
|
b0 = state.B(0)
|
||||||
|
debug_log.append(" ".join((
|
||||||
|
"?",
|
||||||
|
"S0=", (example.x[s0].text if s0 >= 0 else "-"),
|
||||||
|
"B0=", (example.x[b0].text if b0 >= 0 else "-"),
|
||||||
|
"S0 head?", str(state.has_head(state.S(0))),
|
||||||
|
)))
|
||||||
|
s0 = state.S(0)
|
||||||
|
b0 = state.B(0)
|
||||||
|
print("\n".join(debug_log))
|
||||||
|
print("Arc is gold B0, S0?", arc_is_gold(&gold.c, b0, s0))
|
||||||
|
print("Arc is gold S0, B0?", arc_is_gold(&gold.c, s0, b0))
|
||||||
|
print("is_head_unknown(s0)", is_head_unknown(&gold.c, s0))
|
||||||
|
print("is_head_unknown(b0)", is_head_unknown(&gold.c, b0))
|
||||||
|
print("b0", b0, "gold.heads[s0]", gold.c.heads[s0])
|
||||||
|
print("Stack", [example.x[i] for i in state.stack])
|
||||||
|
print("Buffer", [example.x[i] for i in state.queue])
|
||||||
|
raise ValueError(Errors.E024)
|
||||||
|
return history
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user