Restore ArcEager.get_cost function

This commit is contained in:
Matthew Honnibal 2020-06-21 01:11:08 +02:00
parent e90341810c
commit 318a046fb0

View File

@ -562,9 +562,6 @@ cdef class ArcEager(TransitionSystem):
def action_types(self):
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
def get_cost(self, StateClass state, Example gold, action):
raise NotImplementedError
def transition(self, StateClass state, action):
cdef Transition t = self.lookup_transition(action)
t.do(state.c, t.label)
@ -679,6 +676,18 @@ cdef class ArcEager(TransitionSystem):
output[i] = self.c[i].is_valid(st, self.c[i].label)
else:
output[i] = is_valid[self.c[i].move]
def get_cost(self, StateClass stcls, gold, int i):
if not isinstance(gold, ArcEagerGold):
raise TypeError("Expected ArcEagerGold")
cdef ArcEagerGold gold_ = gold
gold_state = gold_.c
n_gold = 0
if self.c[i].is_valid(stcls.c, self.c[i].label):
cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
else:
cost = 9000
return cost
cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass stcls, gold) except -1: