mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
Add some methods to ArcEager that make testing easier
This commit is contained in:
parent
a5f6d69f8a
commit
697bcaa34f
|
@ -366,6 +366,18 @@ cdef class ArcEager(TransitionSystem):
|
|||
def __get__(self):
|
||||
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
|
||||
|
||||
def get_cost(self, StateClass state, GoldParse gold, action):
|
||||
cdef Transition t = self.lookup_transition(action)
|
||||
if not t.is_valid(state.c, t.label):
|
||||
return 9000
|
||||
else:
|
||||
return t.get_cost(state, &gold.c, t.label)
|
||||
|
||||
def transition(self, StateClass state, action):
|
||||
cdef Transition t = self.lookup_transition(action)
|
||||
t.do(state.c, t.label)
|
||||
return state
|
||||
|
||||
def is_gold_parse(self, StateClass state, GoldParse gold):
|
||||
predicted = set()
|
||||
truth = set()
|
||||
|
@ -437,7 +449,10 @@ cdef class ArcEager(TransitionSystem):
|
|||
parses.append((prob, parse))
|
||||
return parses
|
||||
|
||||
cdef Transition lookup_transition(self, object name) except *:
|
||||
cdef Transition lookup_transition(self, object name_or_id) except *:
|
||||
if isinstance(name_or_id, int):
|
||||
return self.c[name_or_id]
|
||||
name = name_or_id
|
||||
if '-' in name:
|
||||
move_str, label_str = name.split('-', 1)
|
||||
label = self.strings[label_str]
|
||||
|
@ -457,6 +472,9 @@ cdef class ArcEager(TransitionSystem):
|
|||
else:
|
||||
return MOVE_NAMES[move]
|
||||
|
||||
def class_name(self, int i):
|
||||
return self.move_name(self.c[i].move, self.c[i].label)
|
||||
|
||||
cdef Transition init_transition(self, int clas, int move, attr_t label) except *:
|
||||
# TODO: Apparent Cython bug here when we try to use the Transition()
|
||||
# constructor with the function pointers
|
||||
|
|
Loading…
Reference in New Issue
Block a user