Add some methods to ArcEager that make testing easier

This commit is contained in:
Matthew Honnibal 2018-04-01 10:41:28 +02:00
parent a5f6d69f8a
commit 697bcaa34f

View File

@ -366,6 +366,18 @@ cdef class ArcEager(TransitionSystem):
def __get__(self):
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
def get_cost(self, StateClass state, GoldParse gold, action):
cdef Transition t = self.lookup_transition(action)
if not t.is_valid(state.c, t.label):
return 9000
else:
return t.get_cost(state, &gold.c, t.label)
def transition(self, StateClass state, action):
cdef Transition t = self.lookup_transition(action)
t.do(state.c, t.label)
return state
def is_gold_parse(self, StateClass state, GoldParse gold):
predicted = set()
truth = set()
@ -437,7 +449,10 @@ cdef class ArcEager(TransitionSystem):
parses.append((prob, parse))
return parses
cdef Transition lookup_transition(self, object name) except *:
cdef Transition lookup_transition(self, object name_or_id) except *:
if isinstance(name_or_id, int):
return self.c[name_or_id]
name = name_or_id
if '-' in name:
move_str, label_str = name.split('-', 1)
label = self.strings[label_str]
@ -457,6 +472,9 @@ cdef class ArcEager(TransitionSystem):
else:
return MOVE_NAMES[move]
def class_name(self, int i):
return self.move_name(self.c[i].move, self.c[i].label)
cdef Transition init_transition(self, int clas, int move, attr_t label) except *:
# TODO: Apparent Cython bug here when we try to use the Transition()
# constructor with the function pointers