* Cost functions now take a StateClass argument instead of a State* pointer.

This commit is contained in:
Matthew Honnibal 2015-06-10 00:40:43 +02:00
parent e0cf61f591
commit 4b98b3e9c8
5 changed files with 65 additions and 72 deletions

View File

@ -53,9 +53,7 @@ MOVE_NAMES[BREAK] = 'B'
# Helper functions for the arc-eager oracle
cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -1:
cdef StateClass stcls = StateClass(st.sent_len)
stcls.from_struct(st)
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) except -1:
cdef int cost = 0
cdef int i, S_i
for i in range(stcls.stack_depth()):
@ -64,13 +62,11 @@ cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -
cost += 1
if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
cost += 1
cost += Break.is_valid(stcls, -1) and Break.move_cost(st, gold) == 0
cost += Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0
return cost
cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1:
cdef StateClass stcls = StateClass(st.sent_len)
stcls.from_struct(st)
cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) except -1:
cdef int cost = 0
cdef int i, B_i
for i in range(stcls.buffer_length()):
@ -81,9 +77,7 @@ cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1
break
return cost
cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child) except -1:
cdef StateClass stcls = StateClass(st.sent_len)
stcls.from_struct(st)
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) except -1:
if arc_is_gold(gold, head, child):
return 0
elif stcls.H(child) == gold.heads[child]:
@ -133,15 +127,15 @@ cdef class Shift:
push_stack(state)
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
return Shift.move_cost(s, gold) + Shift.label_cost(s, gold, label)
cdef int cost(StateClass st, const GoldParseC* gold, int label) except -1:
return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label)
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
return push_cost(s, gold, s.i)
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
return push_cost(s, gold, s.B(0))
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
return 0
@ -160,18 +154,18 @@ cdef class Reduce:
pop_stack(state)
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
if NON_MONOTONIC:
return pop_cost(s, gold, s.stack[0])
return pop_cost(s, gold, s.S(0))
else:
return children_in_buffer(s, s.stack[0], gold.heads)
return children_in_buffer(s, s.S(0), gold.heads)
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
return 0
@ -193,19 +187,19 @@ cdef class LeftArc:
pop_stack(state)
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
if arc_is_gold(gold, s.i, s.stack[0]):
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
if arc_is_gold(gold, s.B(0), s.S(0)):
return 0
else:
return pop_cost(s, gold, s.stack[0]) + arc_cost(s, gold, s.i, s.stack[0])
return pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
return arc_is_gold(gold, s.i, s.stack[0]) and not label_is_gold(gold, s.i, s.stack[0], label)
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
cdef class RightArc:
@ -219,19 +213,19 @@ cdef class RightArc:
push_stack(state)
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
if arc_is_gold(gold, s.stack[0], s.i):
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
if arc_is_gold(gold, s.S(0), s.B(0)):
return 0
else:
return push_cost(s, gold, s.i) + arc_cost(s, gold, s.stack[0], s.i)
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
return arc_is_gold(gold, s.stack[0], s.i) and not label_is_gold(gold, s.stack[0], s.i, label)
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
cdef class Break:
@ -271,21 +265,25 @@ cdef class Break:
state.stack_len -= 1
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
# When we break, we Reduce all of the words on the stack.
cdef int cost = 0
# Number of deps between S0...Sn and N0...Nn
for i in range(s.i, s.sent_len):
cost += children_in_stack(s, i, gold.heads)
cost += head_in_stack(s, i, gold.heads)
cdef int i, B_i, S_i
for i in range(s.buffer_length()):
B_i = s.B(i)
for j in range(s.stack_depth()):
S_i = s.S(j)
cost += gold.heads[B_i] == S_i
cost += gold.heads[S_i] == B_i
return cost
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
return 0
@ -389,7 +387,7 @@ cdef class ArcEager(TransitionSystem):
for i in range(self.n_moves):
output[i] = is_valid[self.c[i].move]
cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1:
cdef int set_costs(self, int* output, StateClass stcls, GoldParse gold) except -1:
cdef int i, move, label
cdef label_cost_func_t[N_MOVES] label_cost_funcs
cdef move_cost_func_t[N_MOVES] move_cost_funcs
@ -411,8 +409,6 @@ cdef class ArcEager(TransitionSystem):
cdef int* labels = gold.c.labels
cdef int* heads = gold.c.heads
cdef StateClass stcls = StateClass(s.sent_len)
stcls.from_struct(s)
self.set_valid(self._is_valid, stcls)
for i in range(self.n_moves):
if not self._is_valid[i]:
@ -421,8 +417,8 @@ cdef class ArcEager(TransitionSystem):
move = self.c[i].move
label = self.c[i].label
if move_costs[move] == -1:
move_costs[move] = move_cost_funcs[move](s, &gold.c)
output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label)
move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
output[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:
cdef bint[N_MOVES] is_valid

View File

@ -36,18 +36,14 @@ MOVE_NAMES[OUT] = 'O'
cdef do_func_t[N_MOVES] do_funcs
cdef bint entity_is_open(const State *s) except -1:
return s.ents_len >= 1 and s.ent.end == 0
cdef bint _entity_is_sunk(const State *s, Transition* golds) except -1:
if not entity_is_open(s):
cdef bint _entity_is_sunk(StateClass st, Transition* golds) except -1:
if not st.entity_is_open():
return False
cdef const Transition* gold = &golds[s.ent.start]
cdef const Transition* gold = &golds[st.E(0)]
if gold.move != BEGIN and gold.move != UNIT:
return True
elif gold.label != s.ent.label:
elif gold.label != st.E_(0).ent_type:
return True
else:
return False
@ -166,7 +162,7 @@ cdef class Missing:
raise NotImplementedError
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
return 9000
@ -187,7 +183,7 @@ cdef class Begin:
s.i += 1
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
cdef int g_act = gold.ner[s.i].move
cdef int g_tag = gold.ner[s.i].label
@ -216,7 +212,7 @@ cdef class In:
s.i += 1
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
move = IN
cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT
cdef int g_act = gold.ner[s.i].move
@ -257,7 +253,7 @@ cdef class Last:
s.i += 1
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
move = LAST
cdef int g_act = gold.ner[s.i].move
@ -301,7 +297,7 @@ cdef class Unit:
s.i += 1
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
cdef int g_act = gold.ner[s.i].move
cdef int g_tag = gold.ner[s.i].label
@ -329,7 +325,7 @@ cdef class Out:
s.i += 1
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
cdef int g_act = gold.ner[s.i].move
cdef int g_tag = gold.ner[s.i].label

View File

@ -144,8 +144,8 @@ cdef class Parser:
_new_fill_context(context, stcls)
scores = self.model.score(context)
guess = self.moves.best_valid(scores, stcls)
best = self.moves.best_gold(scores, state, gold)
cost = guess.get_cost(state, &gold.c, guess.label)
best = self.moves.best_gold(scores, stcls, gold)
cost = guess.get_cost(stcls, &gold.c, guess.label)
self.model.update(context, guess.clas, best.clas, cost)
guess.do(state, guess.label)
loss += cost
@ -191,7 +191,7 @@ cdef class Parser:
if gold is not None:
for i in range(beam.size):
state = <State*>beam.at(i)
self.moves.set_costs(beam.costs[i], state, gold)
self.moves.set_costs(beam.costs[i], stcls, gold)
if follow_gold:
for j in range(self.moves.n_moves):
beam.is_valid[i][j] *= beam.costs[i][j] == 0

View File

@ -18,13 +18,13 @@ cdef struct Transition:
weight_t score
bint (*is_valid)(StateClass state, int label) except -1
int (*get_cost)(const State* state, const GoldParseC* gold, int label) except -1
int (*get_cost)(StateClass state, const GoldParseC* gold, int label) except -1
int (*do)(State* state, int label) except -1
ctypedef int (*get_cost_func_t)(const State* state, const GoldParseC* gold, int label) except -1
ctypedef int (*move_cost_func_t)(const State* state, const GoldParseC* gold) except -1
ctypedef int (*label_cost_func_t)(const State* state, const GoldParseC* gold, int label) except -1
ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) except -1
ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
ctypedef int (*do_func_t)(State* state, int label) except -1
@ -47,9 +47,9 @@ cdef class TransitionSystem:
cdef int set_valid(self, bint* output, StateClass state) except -1
cdef int set_costs(self, int* output, const State* state, GoldParse gold) except -1
cdef int set_costs(self, int* output, StateClass state, GoldParse gold) except -1
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *
cdef Transition best_gold(self, const weight_t* scores, State* state,
cdef Transition best_gold(self, const weight_t* scores, StateClass state,
GoldParse gold) except *

View File

@ -50,21 +50,22 @@ cdef class TransitionSystem:
cdef int set_valid(self, bint* output, StateClass state) except -1:
raise NotImplementedError
cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1:
cdef int set_costs(self, int* output, StateClass stcls, GoldParse gold) except -1:
cdef int i
for i in range(self.n_moves):
output[i] = self.c[i].get_cost(s, &gold.c, self.c[i].label)
if self.c[i].is_valid(stcls, self.c[i].label):
output[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
else:
output[i] = 9000
cdef Transition best_gold(self, const weight_t* scores, const State* s,
cdef Transition best_gold(self, const weight_t* scores, StateClass stcls,
GoldParse gold) except *:
cdef StateClass stcls = StateClass(s.sent_len)
stcls.from_struct(s)
cdef Transition best
cdef weight_t score = MIN_SCORE
cdef int i
for i in range(self.n_moves):
if self.c[i].is_valid(stcls, self.c[i].label):
cost = self.c[i].get_cost(s, &gold.c, self.c[i].label)
cost = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
if scores[i] > score and cost == 0:
best = self.c[i]
score = scores[i]