mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
* Set nogil for oracle functions
This commit is contained in:
parent
4575e7a60f
commit
e5570c9700
|
@ -53,7 +53,7 @@ MOVE_NAMES[BREAK] = 'B'
|
|||
|
||||
# Helper functions for the arc-eager oracle
|
||||
|
||||
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) except -1:
|
||||
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
|
||||
cdef int cost = 0
|
||||
cdef int i, S_i
|
||||
for i in range(stcls.stack_depth()):
|
||||
|
@ -66,7 +66,7 @@ cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) except
|
|||
return cost
|
||||
|
||||
|
||||
cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) except -1:
|
||||
cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
|
||||
cdef int cost = 0
|
||||
cdef int i, B_i
|
||||
for i in range(stcls.buffer_length()):
|
||||
|
@ -77,7 +77,7 @@ cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) except -
|
|||
break
|
||||
return cost
|
||||
|
||||
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) except -1:
|
||||
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil:
|
||||
if arc_is_gold(gold, head, child):
|
||||
return 0
|
||||
elif stcls.H(child) == gold.heads[child]:
|
||||
|
@ -88,7 +88,7 @@ cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child)
|
|||
return 0
|
||||
|
||||
|
||||
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
|
||||
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil:
|
||||
if gold.labels[child] == -1:
|
||||
return True
|
||||
elif _is_gold_root(gold, head) and _is_gold_root(gold, child):
|
||||
|
@ -99,7 +99,7 @@ cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
|
|||
return False
|
||||
|
||||
|
||||
cdef bint label_is_gold(const GoldParseC* gold, int head, int child, int label) except -1:
|
||||
cdef bint label_is_gold(const GoldParseC* gold, int head, int child, int label) nogil:
|
||||
if gold.labels[child] == -1:
|
||||
return True
|
||||
elif label == -1:
|
||||
|
@ -110,75 +110,75 @@ cdef bint label_is_gold(const GoldParseC* gold, int head, int child, int label)
|
|||
return False
|
||||
|
||||
|
||||
cdef bint _is_gold_root(const GoldParseC* gold, int word) except -1:
|
||||
cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
|
||||
return gold.labels[word] == -1 or gold.heads[word] == word
|
||||
|
||||
|
||||
cdef class Shift:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) except -1:
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
return not st.eol()
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass state, int label) except -1:
|
||||
cdef int transition(StateClass state, int label) nogil:
|
||||
# Set the dep label, in case we need it after we reduce
|
||||
if NON_MONOTONIC:
|
||||
state._sent[state.B(0)].dep = label
|
||||
state.push()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass st, const GoldParseC* gold, int label) except -1:
|
||||
cdef int cost(StateClass st, const GoldParseC* gold, int label) nogil:
|
||||
return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
|
||||
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||
return push_cost(s, gold, s.B(0))
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return 0
|
||||
|
||||
|
||||
cdef class Reduce:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) except -1:
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
if NON_MONOTONIC:
|
||||
return st.stack_depth() >= 2 #and not missing_brackets(s)
|
||||
else:
|
||||
return st.stack_depth() >= 2 and st.has_head(st.S(0))
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) except -1:
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
if NON_MONOTONIC and not st.has_head(st.S(0)) and st.stack_depth() >= 2:
|
||||
st.add_arc(st.S(1), st.S(0), st.S_(0).dep)
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
|
||||
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||
if NON_MONOTONIC:
|
||||
return pop_cost(s, gold, s.S(0))
|
||||
else:
|
||||
return children_in_buffer(s, s.S(0), gold.heads)
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return 0
|
||||
|
||||
|
||||
cdef class LeftArc:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) except -1:
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
if NON_MONOTONIC:
|
||||
return st.stack_depth() >= 1 #and not missing_brackets(s)
|
||||
else:
|
||||
return st.stack_depth() >= 1 and not st.has_head(st.S(0))
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) except -1:
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
# Interpret left-arcs from EOL as attachment to root
|
||||
if st.eol():
|
||||
st.add_arc(st.S(0), st.S(0), label)
|
||||
|
@ -187,50 +187,50 @@ cdef class LeftArc:
|
|||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
|
||||
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||
if arc_is_gold(gold, s.B(0), s.S(0)):
|
||||
return 0
|
||||
else:
|
||||
return pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
|
||||
|
||||
|
||||
cdef class RightArc:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) except -1:
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
return st.stack_depth() >= 1 and not st.eol()
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) except -1:
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
st.add_arc(st.S(0), st.B(0), label)
|
||||
st.push()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
cdef inline int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
|
||||
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||
if arc_is_gold(gold, s.S(0), s.B(0)):
|
||||
return 0
|
||||
else:
|
||||
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
|
||||
|
||||
|
||||
cdef class Break:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) except -1:
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
cdef int i
|
||||
if not USE_BREAK:
|
||||
return False
|
||||
|
@ -256,7 +256,7 @@ cdef class Break:
|
|||
return True
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) except -1:
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
st.set_sent_end(st.B(0)-1)
|
||||
while not st.empty():
|
||||
if not st.has_head(st.S(0)):
|
||||
|
@ -264,15 +264,15 @@ cdef class Break:
|
|||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
|
||||
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||
# When we break, we Reduce all of the words on the stack.
|
||||
cdef int cost = 0
|
||||
# Number of deps between S0...Sn and N0...Nn
|
||||
cdef int i, B_i, S_i
|
||||
cdef int i, j, B_i, S_i
|
||||
for i in range(s.buffer_length()):
|
||||
B_i = s.B(i)
|
||||
for j in range(s.stack_depth()):
|
||||
|
@ -282,7 +282,7 @@ cdef class Break:
|
|||
return cost
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return 0
|
||||
|
||||
|
||||
|
@ -411,18 +411,17 @@ cdef class ArcEager(TransitionSystem):
|
|||
cdef int* labels = gold.c.labels
|
||||
cdef int* heads = gold.c.heads
|
||||
|
||||
self.set_valid(self._is_valid, stcls)
|
||||
n_gold = 0
|
||||
for i in range(self.n_moves):
|
||||
if not self._is_valid[i]:
|
||||
output[i] = 9000
|
||||
else:
|
||||
if self.c[i].is_valid(stcls, self.c[i].label):
|
||||
move = self.c[i].move
|
||||
label = self.c[i].label
|
||||
if move_costs[move] == -1:
|
||||
move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
|
||||
output[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
|
||||
n_gold += output[i] == 0
|
||||
else:
|
||||
output[i] = 9000
|
||||
assert n_gold >= 1
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:
|
||||
|
|
|
@ -16,16 +16,16 @@ cdef struct Transition:
|
|||
|
||||
weight_t score
|
||||
|
||||
bint (*is_valid)(StateClass state, int label) except -1
|
||||
int (*get_cost)(StateClass state, const GoldParseC* gold, int label) except -1
|
||||
int (*do)(StateClass state, int label) except -1
|
||||
bint (*is_valid)(StateClass state, int label) nogil
|
||||
int (*get_cost)(StateClass state, const GoldParseC* gold, int label) nogil
|
||||
int (*do)(StateClass state, int label) nogil
|
||||
|
||||
|
||||
ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
|
||||
ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) except -1
|
||||
ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
|
||||
ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
|
||||
ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) nogil
|
||||
ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
|
||||
|
||||
ctypedef int (*do_func_t)(StateClass state, int label) except -1
|
||||
ctypedef int (*do_func_t)(StateClass state, int label) nogil
|
||||
|
||||
|
||||
cdef class TransitionSystem:
|
||||
|
|
Loading…
Reference in New Issue
Block a user