From bc2a2c81c821d3387c6ee9491556bc54d19c0df4 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 1 Apr 2018 10:41:28 +0200 Subject: [PATCH] Add some methods to ArcEager that make testing easier --- spacy/syntax/arc_eager.pyx | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index ca144bde2..e2e7d5f34 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -370,6 +370,18 @@ cdef class ArcEager(TransitionSystem): def __get__(self): return (SHIFT, REDUCE, LEFT, RIGHT, BREAK) + def get_cost(self, StateClass state, GoldParse gold, action): + cdef Transition t = self.lookup_transition(action) + if not t.is_valid(state.c, t.label): + return 9000 + else: + return t.get_cost(state, &gold.c, t.label) + + def transition(self, StateClass state, action): + cdef Transition t = self.lookup_transition(action) + t.do(state.c, t.label) + return state + def is_gold_parse(self, StateClass state, GoldParse gold): predicted = set() truth = set() @@ -441,7 +453,10 @@ cdef class ArcEager(TransitionSystem): parses.append((prob, parse)) return parses - cdef Transition lookup_transition(self, object name) except *: + cdef Transition lookup_transition(self, object name_or_id) except *: + if isinstance(name_or_id, int): + return self.c[name_or_id] + name = name_or_id if '-' in name: move_str, label_str = name.split('-', 1) label = self.strings[label_str] @@ -461,6 +476,9 @@ cdef class ArcEager(TransitionSystem): else: return MOVE_NAMES[move] + def class_name(self, int i): + return self.move_name(self.c[i].move, self.c[i].label) + cdef Transition init_transition(self, int clas, int move, attr_t label) except *: # TODO: Apparent Cython bug here when we try to use the Transition() # constructor with the function pointers