mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-10 15:14:56 +03:00
Start sketching out Split transition implementation
This commit is contained in:
parent
5da7945917
commit
a2f07ab57f
|
@ -35,6 +35,8 @@ cdef enum:
|
|||
|
||||
BREAK
|
||||
|
||||
SPLIT
|
||||
|
||||
N_MOVES
|
||||
|
||||
|
||||
|
@ -44,6 +46,7 @@ MOVE_NAMES[REDUCE] = 'D'
|
|||
MOVE_NAMES[LEFT] = 'L'
|
||||
MOVE_NAMES[RIGHT] = 'R'
|
||||
MOVE_NAMES[BREAK] = 'B'
|
||||
MOVE_NAMES[SPLIT] = 'P'
|
||||
|
||||
|
||||
# Helper functions for the arc-eager oracle
|
||||
|
@ -60,6 +63,10 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no
|
|||
if BINARY_COSTS and cost >= 1:
|
||||
return cost
|
||||
cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0
|
||||
# If the token wasn't split before, but gold says it *should* be split,
|
||||
# don't push (split instead)
|
||||
if USE_SPLIT and not stcls.c.was_split[stcls.c.B(0)]:
|
||||
cost += gold.fused[stcls.c.B(0)]
|
||||
return cost
|
||||
|
||||
|
||||
|
@ -112,6 +119,7 @@ cdef bint label_is_gold(const GoldParseC* gold, int head, int child, attr_t labe
|
|||
cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
|
||||
return gold.heads[word] == word or not gold.has_dep[word]
|
||||
|
||||
|
||||
cdef class Shift:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
|
@ -124,8 +132,6 @@ cdef class Shift:
|
|||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
#if label != 0:
|
||||
# st.split(st.B(1), label)
|
||||
st.shifted[st.B(0)] = 1
|
||||
st.push()
|
||||
|
||||
|
@ -140,10 +146,41 @@ cdef class Shift:
|
|||
@staticmethod
|
||||
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
||||
return 0
|
||||
#if gold.fused_tokens[s.B(1)] == label: TODO
|
||||
# return 0
|
||||
#else:
|
||||
# return 1
|
||||
|
||||
|
||||
cdef class Split:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
if not USE_SPLIT:
|
||||
return 0
|
||||
elif st.buffer_length == 0:
|
||||
return 0
|
||||
elif st.is_split[st.B(0)]:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
st.split(0, label)
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
|
||||
return Split.move_cost(st, gold) + Split.label_cost(st, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
|
||||
if gold.fused[st.B(0)]:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t label_cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
|
||||
if gold.fused[st.B(0)] == label:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
cdef class Reduce:
|
||||
|
@ -247,7 +284,11 @@ cdef class RightArc:
|
|||
|
||||
@staticmethod
|
||||
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||
if arc_is_gold(gold, s.S(0), s.B(0)):
|
||||
# If the token wasn't split before, but gold says it *should* be split,
|
||||
# don't right-arc (split instead)
|
||||
if not s.c.was_split[s.c.B(0)] and gold.fused[s.c.B(0)]:
|
||||
return gold.fused[s.c.B(0)]
|
||||
elif arc_is_gold(gold, s.S(0), s.B(0)):
|
||||
return 0
|
||||
elif s.c.shifted[s.B(0)]:
|
||||
return push_cost(s, gold, s.B(0))
|
||||
|
@ -374,7 +415,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
|
||||
property action_types:
|
||||
def __get__(self):
|
||||
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
|
||||
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK, SPLIT)
|
||||
|
||||
def get_cost(self, StateClass state, GoldParse gold, action):
|
||||
cdef Transition t = self.lookup_transition(action)
|
||||
|
@ -513,6 +554,10 @@ cdef class ArcEager(TransitionSystem):
|
|||
t.is_valid = Break.is_valid
|
||||
t.do = Break.transition
|
||||
t.get_cost = Break.cost
|
||||
elif move == SPLIT:
|
||||
t.is_valid = Split.is_valid
|
||||
t.do = Split.transition
|
||||
t.get_cost = Split.cost
|
||||
else:
|
||||
raise Exception(move)
|
||||
return t
|
||||
|
@ -543,6 +588,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
is_valid[LEFT] = LeftArc.is_valid(st, 0)
|
||||
is_valid[RIGHT] = RightArc.is_valid(st, 0)
|
||||
is_valid[BREAK] = Break.is_valid(st, 0)
|
||||
is_valid[SPLIT] = Split.is_valid(st, 0)
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
output[i] = is_valid[self.c[i].move]
|
||||
|
@ -561,12 +607,14 @@ cdef class ArcEager(TransitionSystem):
|
|||
move_cost_funcs[LEFT] = LeftArc.move_cost
|
||||
move_cost_funcs[RIGHT] = RightArc.move_cost
|
||||
move_cost_funcs[BREAK] = Break.move_cost
|
||||
move_cost_funcs[SPLIT] = Split.move_cost
|
||||
|
||||
label_cost_funcs[SHIFT] = Shift.label_cost
|
||||
label_cost_funcs[REDUCE] = Reduce.label_cost
|
||||
label_cost_funcs[LEFT] = LeftArc.label_cost
|
||||
label_cost_funcs[RIGHT] = RightArc.label_cost
|
||||
label_cost_funcs[BREAK] = Break.label_cost
|
||||
label_cost_funcs[SPLIT] = Split.label_cost
|
||||
|
||||
cdef attr_t* labels = gold.c.labels
|
||||
cdef int* heads = gold.c.heads
|
||||
|
|
Loading…
Reference in New Issue
Block a user