Start sketching out Split transition implementation

This commit is contained in:
Matthew Honnibal 2018-04-01 13:45:41 +02:00
parent 5da7945917
commit a2f07ab57f

View File

@ -35,6 +35,8 @@ cdef enum:
BREAK BREAK
SPLIT
N_MOVES N_MOVES
@ -44,6 +46,7 @@ MOVE_NAMES[REDUCE] = 'D'
MOVE_NAMES[LEFT] = 'L' MOVE_NAMES[LEFT] = 'L'
MOVE_NAMES[RIGHT] = 'R' MOVE_NAMES[RIGHT] = 'R'
MOVE_NAMES[BREAK] = 'B' MOVE_NAMES[BREAK] = 'B'
MOVE_NAMES[SPLIT] = 'P'
# Helper functions for the arc-eager oracle # Helper functions for the arc-eager oracle
@ -60,6 +63,10 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no
if BINARY_COSTS and cost >= 1: if BINARY_COSTS and cost >= 1:
return cost return cost
cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0 cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0
# If the token wasn't split before, but gold says it *should* be split,
# don't push (split instead)
if USE_SPLIT and not stcls.c.was_split[stcls.c.B(0)]:
cost += gold.fused[stcls.c.B(0)]
return cost return cost
@ -112,6 +119,7 @@ cdef bint label_is_gold(const GoldParseC* gold, int head, int child, attr_t labe
cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil: cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
return gold.heads[word] == word or not gold.has_dep[word] return gold.heads[word] == word or not gold.has_dep[word]
cdef class Shift: cdef class Shift:
@staticmethod @staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef bint is_valid(const StateC* st, attr_t label) nogil:
@ -124,8 +132,6 @@ cdef class Shift:
@staticmethod @staticmethod
cdef int transition(StateC* st, attr_t label) nogil: cdef int transition(StateC* st, attr_t label) nogil:
#if label != 0:
# st.split(st.B(1), label)
st.shifted[st.B(0)] = 1 st.shifted[st.B(0)] = 1
st.push() st.push()
@ -140,10 +146,41 @@ cdef class Shift:
@staticmethod @staticmethod
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
return 0 return 0
#if gold.fused_tokens[s.B(1)] == label: TODO
# return 0
#else: cdef class Split:
# return 1 @staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
if not USE_SPLIT:
return 0
elif st.buffer_length == 0:
return 0
elif st.is_split[st.B(0)]:
return 0
else:
return 1
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
st.split(0, label)
@staticmethod
cdef weight_t cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
return Split.move_cost(st, gold) + Split.label_cost(st, gold, label)
@staticmethod
cdef weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
if gold.fused[st.B(0)]:
return 0
else:
return 1
@staticmethod
cdef weight_t label_cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
if gold.fused[st.B(0)] == label:
return 0
else:
return 1
cdef class Reduce: cdef class Reduce:
@ -247,7 +284,11 @@ cdef class RightArc:
@staticmethod @staticmethod
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil: cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
if arc_is_gold(gold, s.S(0), s.B(0)): # If the token wasn't split before, but gold says it *should* be split,
# don't right-arc (split instead)
if not s.c.was_split[s.c.B(0)] and gold.fused[s.c.B(0)]:
return gold.fused[s.c.B(0)]
elif arc_is_gold(gold, s.S(0), s.B(0)):
return 0 return 0
elif s.c.shifted[s.B(0)]: elif s.c.shifted[s.B(0)]:
return push_cost(s, gold, s.B(0)) return push_cost(s, gold, s.B(0))
@ -374,7 +415,7 @@ cdef class ArcEager(TransitionSystem):
property action_types: property action_types:
def __get__(self): def __get__(self):
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK) return (SHIFT, REDUCE, LEFT, RIGHT, BREAK, SPLIT)
def get_cost(self, StateClass state, GoldParse gold, action): def get_cost(self, StateClass state, GoldParse gold, action):
cdef Transition t = self.lookup_transition(action) cdef Transition t = self.lookup_transition(action)
@ -513,6 +554,10 @@ cdef class ArcEager(TransitionSystem):
t.is_valid = Break.is_valid t.is_valid = Break.is_valid
t.do = Break.transition t.do = Break.transition
t.get_cost = Break.cost t.get_cost = Break.cost
elif move == SPLIT:
t.is_valid = Split.is_valid
t.do = Split.transition
t.get_cost = Split.cost
else: else:
raise Exception(move) raise Exception(move)
return t return t
@ -543,6 +588,7 @@ cdef class ArcEager(TransitionSystem):
is_valid[LEFT] = LeftArc.is_valid(st, 0) is_valid[LEFT] = LeftArc.is_valid(st, 0)
is_valid[RIGHT] = RightArc.is_valid(st, 0) is_valid[RIGHT] = RightArc.is_valid(st, 0)
is_valid[BREAK] = Break.is_valid(st, 0) is_valid[BREAK] = Break.is_valid(st, 0)
is_valid[SPLIT] = Split.is_valid(st, 0)
cdef int i cdef int i
for i in range(self.n_moves): for i in range(self.n_moves):
output[i] = is_valid[self.c[i].move] output[i] = is_valid[self.c[i].move]
@ -561,12 +607,14 @@ cdef class ArcEager(TransitionSystem):
move_cost_funcs[LEFT] = LeftArc.move_cost move_cost_funcs[LEFT] = LeftArc.move_cost
move_cost_funcs[RIGHT] = RightArc.move_cost move_cost_funcs[RIGHT] = RightArc.move_cost
move_cost_funcs[BREAK] = Break.move_cost move_cost_funcs[BREAK] = Break.move_cost
move_cost_funcs[SPLIT] = Split.move_cost
label_cost_funcs[SHIFT] = Shift.label_cost label_cost_funcs[SHIFT] = Shift.label_cost
label_cost_funcs[REDUCE] = Reduce.label_cost label_cost_funcs[REDUCE] = Reduce.label_cost
label_cost_funcs[LEFT] = LeftArc.label_cost label_cost_funcs[LEFT] = LeftArc.label_cost
label_cost_funcs[RIGHT] = RightArc.label_cost label_cost_funcs[RIGHT] = RightArc.label_cost
label_cost_funcs[BREAK] = Break.label_cost label_cost_funcs[BREAK] = Break.label_cost
label_cost_funcs[SPLIT] = Split.label_cost
cdef attr_t* labels = gold.c.labels cdef attr_t* labels = gold.c.labels
cdef int* heads = gold.c.heads cdef int* heads = gold.c.heads