mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 17:10:36 +03:00
Restore ArcEager.get_cost function
This commit is contained in:
parent
e90341810c
commit
318a046fb0
|
@ -562,9 +562,6 @@ cdef class ArcEager(TransitionSystem):
|
||||||
def action_types(self):
|
def action_types(self):
|
||||||
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
|
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
|
||||||
|
|
||||||
def get_cost(self, StateClass state, Example gold, action):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def transition(self, StateClass state, action):
|
def transition(self, StateClass state, action):
|
||||||
cdef Transition t = self.lookup_transition(action)
|
cdef Transition t = self.lookup_transition(action)
|
||||||
t.do(state.c, t.label)
|
t.do(state.c, t.label)
|
||||||
|
@ -680,6 +677,18 @@ cdef class ArcEager(TransitionSystem):
|
||||||
else:
|
else:
|
||||||
output[i] = is_valid[self.c[i].move]
|
output[i] = is_valid[self.c[i].move]
|
||||||
|
|
||||||
|
def get_cost(self, StateClass stcls, gold, int i):
|
||||||
|
if not isinstance(gold, ArcEagerGold):
|
||||||
|
raise TypeError("Expected ArcEagerGold")
|
||||||
|
cdef ArcEagerGold gold_ = gold
|
||||||
|
gold_state = gold_.c
|
||||||
|
n_gold = 0
|
||||||
|
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
||||||
|
cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
|
||||||
|
else:
|
||||||
|
cost = 9000
|
||||||
|
return cost
|
||||||
|
|
||||||
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
||||||
StateClass stcls, gold) except -1:
|
StateClass stcls, gold) except -1:
|
||||||
if not isinstance(gold, ArcEagerGold):
|
if not isinstance(gold, ArcEagerGold):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user