mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 09:00:36 +03:00
Restore ArcEager.get_cost function
This commit is contained in:
parent
e90341810c
commit
318a046fb0
|
@ -562,9 +562,6 @@ cdef class ArcEager(TransitionSystem):
|
|||
def action_types(self):
|
||||
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
|
||||
|
||||
def get_cost(self, StateClass state, Example gold, action):
|
||||
raise NotImplementedError
|
||||
|
||||
def transition(self, StateClass state, action):
|
||||
cdef Transition t = self.lookup_transition(action)
|
||||
t.do(state.c, t.label)
|
||||
|
@ -680,6 +677,18 @@ cdef class ArcEager(TransitionSystem):
|
|||
else:
|
||||
output[i] = is_valid[self.c[i].move]
|
||||
|
||||
def get_cost(self, StateClass stcls, gold, int i):
|
||||
if not isinstance(gold, ArcEagerGold):
|
||||
raise TypeError("Expected ArcEagerGold")
|
||||
cdef ArcEagerGold gold_ = gold
|
||||
gold_state = gold_.c
|
||||
n_gold = 0
|
||||
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
||||
cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
|
||||
else:
|
||||
cost = 9000
|
||||
return cost
|
||||
|
||||
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
||||
StateClass stcls, gold) except -1:
|
||||
if not isinstance(gold, ArcEagerGold):
|
||||
|
|
Loading…
Reference in New Issue
Block a user