* Set nogil for oracle functions

This commit is contained in:
Matthew Honnibal 2015-06-10 06:56:35 +02:00
parent 4575e7a60f
commit e5570c9700
2 changed files with 42 additions and 43 deletions

View File

@ -53,7 +53,7 @@ MOVE_NAMES[BREAK] = 'B'
# Helper functions for the arc-eager oracle
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) except -1:
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
cdef int cost = 0
cdef int i, S_i
for i in range(stcls.stack_depth()):
@ -66,7 +66,7 @@ cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) except
return cost
cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) except -1:
cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
cdef int cost = 0
cdef int i, B_i
for i in range(stcls.buffer_length()):
@ -77,7 +77,7 @@ cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) except -
break
return cost
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) except -1:
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil:
if arc_is_gold(gold, head, child):
return 0
elif stcls.H(child) == gold.heads[child]:
@ -88,7 +88,7 @@ cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child)
return 0
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil:
if gold.labels[child] == -1:
return True
elif _is_gold_root(gold, head) and _is_gold_root(gold, child):
@ -99,7 +99,7 @@ cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
return False
cdef bint label_is_gold(const GoldParseC* gold, int head, int child, int label) except -1:
cdef bint label_is_gold(const GoldParseC* gold, int head, int child, int label) nogil:
if gold.labels[child] == -1:
return True
elif label == -1:
@ -110,75 +110,75 @@ cdef bint label_is_gold(const GoldParseC* gold, int head, int child, int label)
return False
cdef bint _is_gold_root(const GoldParseC* gold, int word) except -1:
cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
return gold.labels[word] == -1 or gold.heads[word] == word
cdef class Shift:
@staticmethod
cdef bint is_valid(StateClass st, int label) except -1:
cdef bint is_valid(StateClass st, int label) nogil:
return not st.eol()
@staticmethod
cdef int transition(StateClass state, int label) except -1:
cdef int transition(StateClass state, int label) nogil:
# Set the dep label, in case we need it after we reduce
if NON_MONOTONIC:
state._sent[state.B(0)].dep = label
state.push()
@staticmethod
cdef int cost(StateClass st, const GoldParseC* gold, int label) except -1:
cdef int cost(StateClass st, const GoldParseC* gold, int label) nogil:
return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label)
@staticmethod
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
return push_cost(s, gold, s.B(0))
@staticmethod
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return 0
cdef class Reduce:
@staticmethod
cdef bint is_valid(StateClass st, int label) except -1:
cdef bint is_valid(StateClass st, int label) nogil:
if NON_MONOTONIC:
return st.stack_depth() >= 2 #and not missing_brackets(s)
else:
return st.stack_depth() >= 2 and st.has_head(st.S(0))
@staticmethod
cdef int transition(StateClass st, int label) except -1:
cdef int transition(StateClass st, int label) nogil:
if NON_MONOTONIC and not st.has_head(st.S(0)) and st.stack_depth() >= 2:
st.add_arc(st.S(1), st.S(0), st.S_(0).dep)
st.pop()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
@staticmethod
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
if NON_MONOTONIC:
return pop_cost(s, gold, s.S(0))
else:
return children_in_buffer(s, s.S(0), gold.heads)
@staticmethod
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return 0
cdef class LeftArc:
@staticmethod
cdef bint is_valid(StateClass st, int label) except -1:
cdef bint is_valid(StateClass st, int label) nogil:
if NON_MONOTONIC:
return st.stack_depth() >= 1 #and not missing_brackets(s)
else:
return st.stack_depth() >= 1 and not st.has_head(st.S(0))
@staticmethod
cdef int transition(StateClass st, int label) except -1:
cdef int transition(StateClass st, int label) nogil:
# Interpret left-arcs from EOL as attachment to root
if st.eol():
st.add_arc(st.S(0), st.S(0), label)
@ -187,50 +187,50 @@ cdef class LeftArc:
st.pop()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
@staticmethod
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
if arc_is_gold(gold, s.B(0), s.S(0)):
return 0
else:
return pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
@staticmethod
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
cdef class RightArc:
@staticmethod
cdef bint is_valid(StateClass st, int label) except -1:
cdef bint is_valid(StateClass st, int label) nogil:
return st.stack_depth() >= 1 and not st.eol()
@staticmethod
cdef int transition(StateClass st, int label) except -1:
cdef int transition(StateClass st, int label) nogil:
st.add_arc(st.S(0), st.B(0), label)
st.push()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
cdef inline int cost(StateClass s, const GoldParseC* gold, int label) nogil:
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
@staticmethod
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
if arc_is_gold(gold, s.S(0), s.B(0)):
return 0
else:
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
@staticmethod
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
cdef class Break:
@staticmethod
cdef bint is_valid(StateClass st, int label) except -1:
cdef bint is_valid(StateClass st, int label) nogil:
cdef int i
if not USE_BREAK:
return False
@ -256,7 +256,7 @@ cdef class Break:
return True
@staticmethod
cdef int transition(StateClass st, int label) except -1:
cdef int transition(StateClass st, int label) nogil:
st.set_sent_end(st.B(0)-1)
while not st.empty():
if not st.has_head(st.S(0)):
@ -264,15 +264,15 @@ cdef class Break:
st.pop()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
@staticmethod
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
# When we break, we Reduce all of the words on the stack.
cdef int cost = 0
# Number of deps between S0...Sn and N0...Nn
cdef int i, B_i, S_i
cdef int i, j, B_i, S_i
for i in range(s.buffer_length()):
B_i = s.B(i)
for j in range(s.stack_depth()):
@ -282,7 +282,7 @@ cdef class Break:
return cost
@staticmethod
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return 0
@ -411,18 +411,17 @@ cdef class ArcEager(TransitionSystem):
cdef int* labels = gold.c.labels
cdef int* heads = gold.c.heads
self.set_valid(self._is_valid, stcls)
n_gold = 0
for i in range(self.n_moves):
if not self._is_valid[i]:
output[i] = 9000
else:
if self.c[i].is_valid(stcls, self.c[i].label):
move = self.c[i].move
label = self.c[i].label
if move_costs[move] == -1:
move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
output[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
n_gold += output[i] == 0
else:
output[i] = 9000
assert n_gold >= 1
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:

View File

@ -16,16 +16,16 @@ cdef struct Transition:
weight_t score
bint (*is_valid)(StateClass state, int label) except -1
int (*get_cost)(StateClass state, const GoldParseC* gold, int label) except -1
int (*do)(StateClass state, int label) except -1
bint (*is_valid)(StateClass state, int label) nogil
int (*get_cost)(StateClass state, const GoldParseC* gold, int label) nogil
int (*do)(StateClass state, int label) nogil
ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) except -1
ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) nogil
ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
ctypedef int (*do_func_t)(StateClass state, int label) except -1
ctypedef int (*do_func_t)(StateClass state, int label) nogil
cdef class TransitionSystem: