From 318a046fb094d42e4490c05d8a723696f878c30b Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 21 Jun 2020 01:11:08 +0200 Subject: [PATCH] Restore ArcEager.get_cost function --- spacy/syntax/arc_eager.pyx | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 13879d898..c7ecbceea 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -562,9 +562,6 @@ cdef class ArcEager(TransitionSystem): def action_types(self): return (SHIFT, REDUCE, LEFT, RIGHT, BREAK) - def get_cost(self, StateClass state, Example gold, action): - raise NotImplementedError - def transition(self, StateClass state, action): cdef Transition t = self.lookup_transition(action) t.do(state.c, t.label) @@ -679,6 +676,18 @@ cdef class ArcEager(TransitionSystem): output[i] = self.c[i].is_valid(st, self.c[i].label) else: output[i] = is_valid[self.c[i].move] + + def get_cost(self, StateClass stcls, gold, int i): + if not isinstance(gold, ArcEagerGold): + raise TypeError("Expected ArcEagerGold") + cdef ArcEagerGold gold_ = gold + gold_state = gold_.c + n_gold = 0 + if self.c[i].is_valid(stcls.c, self.c[i].label): + cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label) + else: + cost = 9000 + return cost cdef int set_costs(self, int* is_valid, weight_t* costs, StateClass stcls, gold) except -1: