mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 17:10:36 +03:00
Work on parser oracle
This commit is contained in:
parent
914924a68b
commit
c58deb3546
|
@ -82,7 +82,7 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
|
||||||
gold_i = cand_to_gold[cand_i]
|
gold_i = cand_to_gold[cand_i]
|
||||||
if gold_i is not None: # Alignment found
|
if gold_i is not None: # Alignment found
|
||||||
ref_tok = example.y.c[gold_i]
|
ref_tok = example.y.c[gold_i]
|
||||||
gold_head = gold_to_cand[ref_tok.head + gold_i]
|
gold_head = gold_to_cand[gold_i + ref_tok.head]
|
||||||
if gold_head is not None:
|
if gold_head is not None:
|
||||||
gs.heads[cand_i] = gold_head
|
gs.heads[cand_i] = gold_head
|
||||||
gs.labels[cand_i] = ref_tok.dep
|
gs.labels[cand_i] = ref_tok.dep
|
||||||
|
@ -106,17 +106,17 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
|
||||||
stack_words = set()
|
stack_words = set()
|
||||||
for i in range(stcls.stack_depth()):
|
for i in range(stcls.stack_depth()):
|
||||||
s_i = stcls.S(i)
|
s_i = stcls.S(i)
|
||||||
head = s_i + gs.heads[s_i]
|
head = gs.heads[s_i]
|
||||||
gs.n_kids_in_stack[head] += 1
|
gs.n_kids_in_stack[head] += 1
|
||||||
stack_words.add(s_i)
|
stack_words.add(s_i)
|
||||||
buffer_words = set()
|
buffer_words = set()
|
||||||
for i in range(stcls.buffer_length()):
|
for i in range(stcls.buffer_length()):
|
||||||
b_i = stcls.B(i)
|
b_i = stcls.B(i)
|
||||||
head = b_i + gs.heads[b_i]
|
head = gs.heads[b_i]
|
||||||
gs.n_kids_in_buffer[head] += 1
|
gs.n_kids_in_buffer[head] += 1
|
||||||
buffer_words.add(b_i)
|
buffer_words.add(b_i)
|
||||||
for i in range(gs.length):
|
for i in range(gs.length):
|
||||||
head = i + gs.heads[i]
|
head = gs.heads[i]
|
||||||
if head in stack_words:
|
if head in stack_words:
|
||||||
gs.state_bits[i] = set_state_flag(
|
gs.state_bits[i] = set_state_flag(
|
||||||
gs.state_bits[i],
|
gs.state_bits[i],
|
||||||
|
@ -142,6 +142,58 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
|
||||||
return gs
|
return gs
|
||||||
|
|
||||||
|
|
||||||
|
cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) except *:
|
||||||
|
for i in range(gs.length):
|
||||||
|
gs.state_bits[i] = set_state_flag(
|
||||||
|
gs.state_bits[i],
|
||||||
|
HEAD_IN_BUFFER,
|
||||||
|
0
|
||||||
|
)
|
||||||
|
gs.state_bits[i] = set_state_flag(
|
||||||
|
gs.state_bits[i],
|
||||||
|
HEAD_IN_STACK,
|
||||||
|
0
|
||||||
|
)
|
||||||
|
gs.n_kids_in_stack[i] = 0
|
||||||
|
gs.n_kids_in_buffer[i] = 0
|
||||||
|
stack_words = set()
|
||||||
|
for i in range(stcls.stack_depth()):
|
||||||
|
s_i = stcls.S(i)
|
||||||
|
head = gs.heads[s_i]
|
||||||
|
gs.n_kids_in_stack[head] += 1
|
||||||
|
stack_words.add(s_i)
|
||||||
|
buffer_words = set()
|
||||||
|
for i in range(stcls.buffer_length()):
|
||||||
|
b_i = stcls.B(i)
|
||||||
|
head = gs.heads[b_i]
|
||||||
|
gs.n_kids_in_buffer[head] += 1
|
||||||
|
buffer_words.add(b_i)
|
||||||
|
for i in range(gs.length):
|
||||||
|
head = gs.heads[i]
|
||||||
|
if head in stack_words:
|
||||||
|
gs.state_bits[i] = set_state_flag(
|
||||||
|
gs.state_bits[i],
|
||||||
|
HEAD_IN_STACK,
|
||||||
|
1
|
||||||
|
)
|
||||||
|
gs.state_bits[i] = set_state_flag(
|
||||||
|
gs.state_bits[i],
|
||||||
|
HEAD_IN_BUFFER,
|
||||||
|
0
|
||||||
|
)
|
||||||
|
elif head in buffer_words:
|
||||||
|
gs.state_bits[i] = set_state_flag(
|
||||||
|
gs.state_bits[i],
|
||||||
|
HEAD_IN_STACK,
|
||||||
|
0
|
||||||
|
)
|
||||||
|
gs.state_bits[i] = set_state_flag(
|
||||||
|
gs.state_bits[i],
|
||||||
|
HEAD_IN_BUFFER,
|
||||||
|
1
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
cdef class ArcEagerGold:
|
cdef class ArcEagerGold:
|
||||||
cdef GoldParseStateC c
|
cdef GoldParseStateC c
|
||||||
cdef Pool mem
|
cdef Pool mem
|
||||||
|
@ -150,6 +202,9 @@ cdef class ArcEagerGold:
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
self.c = create_gold_state(self.mem, stcls, example)
|
self.c = create_gold_state(self.mem, stcls, example)
|
||||||
|
|
||||||
|
def update(self, StateClass stcls):
|
||||||
|
update_gold_state(&self.c, stcls)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cdef int check_state_gold(char state_bits, char flag) nogil:
|
cdef int check_state_gold(char state_bits, char flag) nogil:
|
||||||
|
@ -319,22 +374,27 @@ cdef class LeftArc:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||||
gold = <const GoldParseStateC*>_gold
|
gold = <const GoldParseStateC*>_gold
|
||||||
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
|
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
|
cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil:
|
||||||
gold = <const GoldParseStateC*>_gold
|
cdef weight_t cost = 0
|
||||||
if arc_is_gold(gold, s.S(0), s.B(0)):
|
s0 = s.S(0)
|
||||||
return 0
|
b0 = s.B(0)
|
||||||
elif s.c.shifted[s.B(0)]:
|
if arc_is_gold(gold, b0, s0):
|
||||||
return push_cost(s, gold, s.B(0))
|
# Have a negative cost if we 'recover' from the wrong dependency
|
||||||
|
return 0 if not s.has_head(s0) else -1
|
||||||
else:
|
else:
|
||||||
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
|
# Account for deps we might lose between S0 and stack
|
||||||
|
if not s.has_head(s0):
|
||||||
|
cost += gold.n_kids_in_stack[s0]
|
||||||
|
if is_head_in_buffer(gold, s0):
|
||||||
|
cost += 1
|
||||||
|
return cost + pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t label_cost(StateClass s, const void* _gold, attr_t label) nogil:
|
cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil:
|
||||||
gold = <const GoldParseStateC*>_gold
|
return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
|
||||||
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
|
|
||||||
|
|
||||||
|
|
||||||
cdef class RightArc:
|
cdef class RightArc:
|
||||||
|
@ -622,42 +682,17 @@ cdef class ArcEager(TransitionSystem):
|
||||||
|
|
||||||
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
||||||
StateClass stcls, gold) except -1:
|
StateClass stcls, gold) except -1:
|
||||||
gold_state = (<ArcEagerGold>gold).c
|
if not isinstance(gold, ArcEagerGold):
|
||||||
cdef int i, move
|
raise TypeError("Expected ArcEagerGold")
|
||||||
cdef attr_t label
|
cdef ArcEagerGold gold_ = gold
|
||||||
cdef label_cost_func_t[N_MOVES] label_cost_funcs
|
gold_.update(stcls)
|
||||||
cdef move_cost_func_t[N_MOVES] move_cost_funcs
|
gold_state = gold_.c
|
||||||
cdef weight_t[N_MOVES] move_costs
|
|
||||||
for i in range(N_MOVES):
|
|
||||||
move_costs[i] = 9000
|
|
||||||
move_cost_funcs[SHIFT] = Shift.move_cost
|
|
||||||
move_cost_funcs[REDUCE] = Reduce.move_cost
|
|
||||||
move_cost_funcs[LEFT] = LeftArc.move_cost
|
|
||||||
move_cost_funcs[RIGHT] = RightArc.move_cost
|
|
||||||
move_cost_funcs[BREAK] = Break.move_cost
|
|
||||||
|
|
||||||
label_cost_funcs[SHIFT] = Shift.label_cost
|
|
||||||
label_cost_funcs[REDUCE] = Reduce.label_cost
|
|
||||||
label_cost_funcs[LEFT] = LeftArc.label_cost
|
|
||||||
label_cost_funcs[RIGHT] = RightArc.label_cost
|
|
||||||
label_cost_funcs[BREAK] = Break.label_cost
|
|
||||||
|
|
||||||
cdef attr_t* labels = gold_state.labels
|
|
||||||
cdef int32_t* heads = gold_state.heads
|
|
||||||
|
|
||||||
n_gold = 0
|
n_gold = 0
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
||||||
is_valid[i] = True
|
is_valid[i] = True
|
||||||
move = self.c[i].move
|
costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
|
||||||
label = self.c[i].label
|
n_gold += 1
|
||||||
if move_costs[move] == 9000:
|
|
||||||
move_costs[move] = move_cost_funcs[move](stcls, &gold_state)
|
|
||||||
move_cost = move_costs[move]
|
|
||||||
label_cost = label_cost_funcs[move](stcls, &gold_state, label)
|
|
||||||
costs[i] = move_cost + label_cost
|
|
||||||
n_gold += costs[i] <= 0
|
|
||||||
print(move, label, costs[i])
|
|
||||||
else:
|
else:
|
||||||
is_valid[i] = False
|
is_valid[i] = False
|
||||||
costs[i] = 9000
|
costs[i] = 9000
|
||||||
|
|
Loading…
Reference in New Issue
Block a user