mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-10 15:14:56 +03:00
Start sketching out Split transition implementation
This commit is contained in:
parent
5da7945917
commit
a2f07ab57f
|
@ -35,6 +35,8 @@ cdef enum:
|
||||||
|
|
||||||
BREAK
|
BREAK
|
||||||
|
|
||||||
|
SPLIT
|
||||||
|
|
||||||
N_MOVES
|
N_MOVES
|
||||||
|
|
||||||
|
|
||||||
|
@ -44,6 +46,7 @@ MOVE_NAMES[REDUCE] = 'D'
|
||||||
MOVE_NAMES[LEFT] = 'L'
|
MOVE_NAMES[LEFT] = 'L'
|
||||||
MOVE_NAMES[RIGHT] = 'R'
|
MOVE_NAMES[RIGHT] = 'R'
|
||||||
MOVE_NAMES[BREAK] = 'B'
|
MOVE_NAMES[BREAK] = 'B'
|
||||||
|
MOVE_NAMES[SPLIT] = 'P'
|
||||||
|
|
||||||
|
|
||||||
# Helper functions for the arc-eager oracle
|
# Helper functions for the arc-eager oracle
|
||||||
|
@ -60,6 +63,10 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no
|
||||||
if BINARY_COSTS and cost >= 1:
|
if BINARY_COSTS and cost >= 1:
|
||||||
return cost
|
return cost
|
||||||
cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0
|
cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0
|
||||||
|
# If the token wasn't split before, but gold says it *should* be split,
|
||||||
|
# don't push (split instead)
|
||||||
|
if USE_SPLIT and not stcls.c.was_split[stcls.c.B(0)]:
|
||||||
|
cost += gold.fused[stcls.c.B(0)]
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
@ -112,6 +119,7 @@ cdef bint label_is_gold(const GoldParseC* gold, int head, int child, attr_t labe
|
||||||
cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
|
cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
|
||||||
return gold.heads[word] == word or not gold.has_dep[word]
|
return gold.heads[word] == word or not gold.has_dep[word]
|
||||||
|
|
||||||
|
|
||||||
cdef class Shift:
|
cdef class Shift:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||||
|
@ -124,8 +132,6 @@ cdef class Shift:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(StateC* st, attr_t label) nogil:
|
cdef int transition(StateC* st, attr_t label) nogil:
|
||||||
#if label != 0:
|
|
||||||
# st.split(st.B(1), label)
|
|
||||||
st.shifted[st.B(0)] = 1
|
st.shifted[st.B(0)] = 1
|
||||||
st.push()
|
st.push()
|
||||||
|
|
||||||
|
@ -140,10 +146,41 @@ cdef class Shift:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
||||||
return 0
|
return 0
|
||||||
#if gold.fused_tokens[s.B(1)] == label: TODO
|
|
||||||
# return 0
|
|
||||||
#else:
|
cdef class Split:
|
||||||
# return 1
|
@staticmethod
|
||||||
|
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||||
|
if not USE_SPLIT:
|
||||||
|
return 0
|
||||||
|
elif st.buffer_length == 0:
|
||||||
|
return 0
|
||||||
|
elif st.is_split[st.B(0)]:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef int transition(StateC* st, attr_t label) nogil:
|
||||||
|
st.split(0, label)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef weight_t cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
|
||||||
|
return Split.move_cost(st, gold) + Split.label_cost(st, gold, label)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
|
||||||
|
if gold.fused[st.B(0)]:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef weight_t label_cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
|
||||||
|
if gold.fused[st.B(0)] == label:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
cdef class Reduce:
|
cdef class Reduce:
|
||||||
|
@ -247,7 +284,11 @@ cdef class RightArc:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
|
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||||
if arc_is_gold(gold, s.S(0), s.B(0)):
|
# If the token wasn't split before, but gold says it *should* be split,
|
||||||
|
# don't right-arc (split instead)
|
||||||
|
if not s.c.was_split[s.c.B(0)] and gold.fused[s.c.B(0)]:
|
||||||
|
return gold.fused[s.c.B(0)]
|
||||||
|
elif arc_is_gold(gold, s.S(0), s.B(0)):
|
||||||
return 0
|
return 0
|
||||||
elif s.c.shifted[s.B(0)]:
|
elif s.c.shifted[s.B(0)]:
|
||||||
return push_cost(s, gold, s.B(0))
|
return push_cost(s, gold, s.B(0))
|
||||||
|
@ -374,7 +415,7 @@ cdef class ArcEager(TransitionSystem):
|
||||||
|
|
||||||
property action_types:
|
property action_types:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
|
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK, SPLIT)
|
||||||
|
|
||||||
def get_cost(self, StateClass state, GoldParse gold, action):
|
def get_cost(self, StateClass state, GoldParse gold, action):
|
||||||
cdef Transition t = self.lookup_transition(action)
|
cdef Transition t = self.lookup_transition(action)
|
||||||
|
@ -513,6 +554,10 @@ cdef class ArcEager(TransitionSystem):
|
||||||
t.is_valid = Break.is_valid
|
t.is_valid = Break.is_valid
|
||||||
t.do = Break.transition
|
t.do = Break.transition
|
||||||
t.get_cost = Break.cost
|
t.get_cost = Break.cost
|
||||||
|
elif move == SPLIT:
|
||||||
|
t.is_valid = Split.is_valid
|
||||||
|
t.do = Split.transition
|
||||||
|
t.get_cost = Split.cost
|
||||||
else:
|
else:
|
||||||
raise Exception(move)
|
raise Exception(move)
|
||||||
return t
|
return t
|
||||||
|
@ -543,6 +588,7 @@ cdef class ArcEager(TransitionSystem):
|
||||||
is_valid[LEFT] = LeftArc.is_valid(st, 0)
|
is_valid[LEFT] = LeftArc.is_valid(st, 0)
|
||||||
is_valid[RIGHT] = RightArc.is_valid(st, 0)
|
is_valid[RIGHT] = RightArc.is_valid(st, 0)
|
||||||
is_valid[BREAK] = Break.is_valid(st, 0)
|
is_valid[BREAK] = Break.is_valid(st, 0)
|
||||||
|
is_valid[SPLIT] = Split.is_valid(st, 0)
|
||||||
cdef int i
|
cdef int i
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
output[i] = is_valid[self.c[i].move]
|
output[i] = is_valid[self.c[i].move]
|
||||||
|
@ -561,12 +607,14 @@ cdef class ArcEager(TransitionSystem):
|
||||||
move_cost_funcs[LEFT] = LeftArc.move_cost
|
move_cost_funcs[LEFT] = LeftArc.move_cost
|
||||||
move_cost_funcs[RIGHT] = RightArc.move_cost
|
move_cost_funcs[RIGHT] = RightArc.move_cost
|
||||||
move_cost_funcs[BREAK] = Break.move_cost
|
move_cost_funcs[BREAK] = Break.move_cost
|
||||||
|
move_cost_funcs[SPLIT] = Split.move_cost
|
||||||
|
|
||||||
label_cost_funcs[SHIFT] = Shift.label_cost
|
label_cost_funcs[SHIFT] = Shift.label_cost
|
||||||
label_cost_funcs[REDUCE] = Reduce.label_cost
|
label_cost_funcs[REDUCE] = Reduce.label_cost
|
||||||
label_cost_funcs[LEFT] = LeftArc.label_cost
|
label_cost_funcs[LEFT] = LeftArc.label_cost
|
||||||
label_cost_funcs[RIGHT] = RightArc.label_cost
|
label_cost_funcs[RIGHT] = RightArc.label_cost
|
||||||
label_cost_funcs[BREAK] = Break.label_cost
|
label_cost_funcs[BREAK] = Break.label_cost
|
||||||
|
label_cost_funcs[SPLIT] = Split.label_cost
|
||||||
|
|
||||||
cdef attr_t* labels = gold.c.labels
|
cdef attr_t* labels = gold.c.labels
|
||||||
cdef int* heads = gold.c.heads
|
cdef int* heads = gold.c.heads
|
||||||
|
|
Loading…
Reference in New Issue
Block a user