mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 09:44:36 +03:00
* Set nogil for oracle functions
This commit is contained in:
parent
4575e7a60f
commit
e5570c9700
|
@ -53,7 +53,7 @@ MOVE_NAMES[BREAK] = 'B'
|
||||||
|
|
||||||
# Helper functions for the arc-eager oracle
|
# Helper functions for the arc-eager oracle
|
||||||
|
|
||||||
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) except -1:
|
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
|
||||||
cdef int cost = 0
|
cdef int cost = 0
|
||||||
cdef int i, S_i
|
cdef int i, S_i
|
||||||
for i in range(stcls.stack_depth()):
|
for i in range(stcls.stack_depth()):
|
||||||
|
@ -66,7 +66,7 @@ cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) except
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) except -1:
|
cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
|
||||||
cdef int cost = 0
|
cdef int cost = 0
|
||||||
cdef int i, B_i
|
cdef int i, B_i
|
||||||
for i in range(stcls.buffer_length()):
|
for i in range(stcls.buffer_length()):
|
||||||
|
@ -77,7 +77,7 @@ cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) except -
|
||||||
break
|
break
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) except -1:
|
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil:
|
||||||
if arc_is_gold(gold, head, child):
|
if arc_is_gold(gold, head, child):
|
||||||
return 0
|
return 0
|
||||||
elif stcls.H(child) == gold.heads[child]:
|
elif stcls.H(child) == gold.heads[child]:
|
||||||
|
@ -88,7 +88,7 @@ cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
|
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil:
|
||||||
if gold.labels[child] == -1:
|
if gold.labels[child] == -1:
|
||||||
return True
|
return True
|
||||||
elif _is_gold_root(gold, head) and _is_gold_root(gold, child):
|
elif _is_gold_root(gold, head) and _is_gold_root(gold, child):
|
||||||
|
@ -99,7 +99,7 @@ cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
cdef bint label_is_gold(const GoldParseC* gold, int head, int child, int label) except -1:
|
cdef bint label_is_gold(const GoldParseC* gold, int head, int child, int label) nogil:
|
||||||
if gold.labels[child] == -1:
|
if gold.labels[child] == -1:
|
||||||
return True
|
return True
|
||||||
elif label == -1:
|
elif label == -1:
|
||||||
|
@ -110,75 +110,75 @@ cdef bint label_is_gold(const GoldParseC* gold, int head, int child, int label)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
cdef bint _is_gold_root(const GoldParseC* gold, int word) except -1:
|
cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
|
||||||
return gold.labels[word] == -1 or gold.heads[word] == word
|
return gold.labels[word] == -1 or gold.heads[word] == word
|
||||||
|
|
||||||
|
|
||||||
cdef class Shift:
|
cdef class Shift:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef bint is_valid(StateClass st, int label) except -1:
|
cdef bint is_valid(StateClass st, int label) nogil:
|
||||||
return not st.eol()
|
return not st.eol()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(StateClass state, int label) except -1:
|
cdef int transition(StateClass state, int label) nogil:
|
||||||
# Set the dep label, in case we need it after we reduce
|
# Set the dep label, in case we need it after we reduce
|
||||||
if NON_MONOTONIC:
|
if NON_MONOTONIC:
|
||||||
state._sent[state.B(0)].dep = label
|
state._sent[state.B(0)].dep = label
|
||||||
state.push()
|
state.push()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(StateClass st, const GoldParseC* gold, int label) except -1:
|
cdef int cost(StateClass st, const GoldParseC* gold, int label) nogil:
|
||||||
return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label)
|
return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
|
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||||
return push_cost(s, gold, s.B(0))
|
return push_cost(s, gold, s.B(0))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
cdef class Reduce:
|
cdef class Reduce:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef bint is_valid(StateClass st, int label) except -1:
|
cdef bint is_valid(StateClass st, int label) nogil:
|
||||||
if NON_MONOTONIC:
|
if NON_MONOTONIC:
|
||||||
return st.stack_depth() >= 2 #and not missing_brackets(s)
|
return st.stack_depth() >= 2 #and not missing_brackets(s)
|
||||||
else:
|
else:
|
||||||
return st.stack_depth() >= 2 and st.has_head(st.S(0))
|
return st.stack_depth() >= 2 and st.has_head(st.S(0))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(StateClass st, int label) except -1:
|
cdef int transition(StateClass st, int label) nogil:
|
||||||
if NON_MONOTONIC and not st.has_head(st.S(0)) and st.stack_depth() >= 2:
|
if NON_MONOTONIC and not st.has_head(st.S(0)) and st.stack_depth() >= 2:
|
||||||
st.add_arc(st.S(1), st.S(0), st.S_(0).dep)
|
st.add_arc(st.S(1), st.S(0), st.S_(0).dep)
|
||||||
st.pop()
|
st.pop()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||||
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
|
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
|
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||||
if NON_MONOTONIC:
|
if NON_MONOTONIC:
|
||||||
return pop_cost(s, gold, s.S(0))
|
return pop_cost(s, gold, s.S(0))
|
||||||
else:
|
else:
|
||||||
return children_in_buffer(s, s.S(0), gold.heads)
|
return children_in_buffer(s, s.S(0), gold.heads)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
cdef class LeftArc:
|
cdef class LeftArc:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef bint is_valid(StateClass st, int label) except -1:
|
cdef bint is_valid(StateClass st, int label) nogil:
|
||||||
if NON_MONOTONIC:
|
if NON_MONOTONIC:
|
||||||
return st.stack_depth() >= 1 #and not missing_brackets(s)
|
return st.stack_depth() >= 1 #and not missing_brackets(s)
|
||||||
else:
|
else:
|
||||||
return st.stack_depth() >= 1 and not st.has_head(st.S(0))
|
return st.stack_depth() >= 1 and not st.has_head(st.S(0))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(StateClass st, int label) except -1:
|
cdef int transition(StateClass st, int label) nogil:
|
||||||
# Interpret left-arcs from EOL as attachment to root
|
# Interpret left-arcs from EOL as attachment to root
|
||||||
if st.eol():
|
if st.eol():
|
||||||
st.add_arc(st.S(0), st.S(0), label)
|
st.add_arc(st.S(0), st.S(0), label)
|
||||||
|
@ -187,50 +187,50 @@ cdef class LeftArc:
|
||||||
st.pop()
|
st.pop()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||||
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
|
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
|
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||||
if arc_is_gold(gold, s.B(0), s.S(0)):
|
if arc_is_gold(gold, s.B(0), s.S(0)):
|
||||||
return 0
|
return 0
|
||||||
else:
|
else:
|
||||||
return pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
|
return pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||||
return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
|
return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
|
||||||
|
|
||||||
|
|
||||||
cdef class RightArc:
|
cdef class RightArc:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef bint is_valid(StateClass st, int label) except -1:
|
cdef bint is_valid(StateClass st, int label) nogil:
|
||||||
return st.stack_depth() >= 1 and not st.eol()
|
return st.stack_depth() >= 1 and not st.eol()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(StateClass st, int label) except -1:
|
cdef int transition(StateClass st, int label) nogil:
|
||||||
st.add_arc(st.S(0), st.B(0), label)
|
st.add_arc(st.S(0), st.B(0), label)
|
||||||
st.push()
|
st.push()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
cdef inline int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||||
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
|
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
|
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||||
if arc_is_gold(gold, s.S(0), s.B(0)):
|
if arc_is_gold(gold, s.S(0), s.B(0)):
|
||||||
return 0
|
return 0
|
||||||
else:
|
else:
|
||||||
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
|
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||||
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
|
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
|
||||||
|
|
||||||
|
|
||||||
cdef class Break:
|
cdef class Break:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef bint is_valid(StateClass st, int label) except -1:
|
cdef bint is_valid(StateClass st, int label) nogil:
|
||||||
cdef int i
|
cdef int i
|
||||||
if not USE_BREAK:
|
if not USE_BREAK:
|
||||||
return False
|
return False
|
||||||
|
@ -256,7 +256,7 @@ cdef class Break:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(StateClass st, int label) except -1:
|
cdef int transition(StateClass st, int label) nogil:
|
||||||
st.set_sent_end(st.B(0)-1)
|
st.set_sent_end(st.B(0)-1)
|
||||||
while not st.empty():
|
while not st.empty():
|
||||||
if not st.has_head(st.S(0)):
|
if not st.has_head(st.S(0)):
|
||||||
|
@ -264,15 +264,15 @@ cdef class Break:
|
||||||
st.pop()
|
st.pop()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||||
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
|
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
|
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||||
# When we break, we Reduce all of the words on the stack.
|
# When we break, we Reduce all of the words on the stack.
|
||||||
cdef int cost = 0
|
cdef int cost = 0
|
||||||
# Number of deps between S0...Sn and N0...Nn
|
# Number of deps between S0...Sn and N0...Nn
|
||||||
cdef int i, B_i, S_i
|
cdef int i, j, B_i, S_i
|
||||||
for i in range(s.buffer_length()):
|
for i in range(s.buffer_length()):
|
||||||
B_i = s.B(i)
|
B_i = s.B(i)
|
||||||
for j in range(s.stack_depth()):
|
for j in range(s.stack_depth()):
|
||||||
|
@ -282,7 +282,7 @@ cdef class Break:
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
@ -411,18 +411,17 @@ cdef class ArcEager(TransitionSystem):
|
||||||
cdef int* labels = gold.c.labels
|
cdef int* labels = gold.c.labels
|
||||||
cdef int* heads = gold.c.heads
|
cdef int* heads = gold.c.heads
|
||||||
|
|
||||||
self.set_valid(self._is_valid, stcls)
|
|
||||||
n_gold = 0
|
n_gold = 0
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
if not self._is_valid[i]:
|
if self.c[i].is_valid(stcls, self.c[i].label):
|
||||||
output[i] = 9000
|
|
||||||
else:
|
|
||||||
move = self.c[i].move
|
move = self.c[i].move
|
||||||
label = self.c[i].label
|
label = self.c[i].label
|
||||||
if move_costs[move] == -1:
|
if move_costs[move] == -1:
|
||||||
move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
|
move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
|
||||||
output[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
|
output[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
|
||||||
n_gold += output[i] == 0
|
n_gold += output[i] == 0
|
||||||
|
else:
|
||||||
|
output[i] = 9000
|
||||||
assert n_gold >= 1
|
assert n_gold >= 1
|
||||||
|
|
||||||
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:
|
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:
|
||||||
|
|
|
@ -16,16 +16,16 @@ cdef struct Transition:
|
||||||
|
|
||||||
weight_t score
|
weight_t score
|
||||||
|
|
||||||
bint (*is_valid)(StateClass state, int label) except -1
|
bint (*is_valid)(StateClass state, int label) nogil
|
||||||
int (*get_cost)(StateClass state, const GoldParseC* gold, int label) except -1
|
int (*get_cost)(StateClass state, const GoldParseC* gold, int label) nogil
|
||||||
int (*do)(StateClass state, int label) except -1
|
int (*do)(StateClass state, int label) nogil
|
||||||
|
|
||||||
|
|
||||||
ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
|
ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
|
||||||
ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) except -1
|
ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) nogil
|
||||||
ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
|
ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
|
||||||
|
|
||||||
ctypedef int (*do_func_t)(StateClass state, int label) except -1
|
ctypedef int (*do_func_t)(StateClass state, int label) nogil
|
||||||
|
|
||||||
|
|
||||||
cdef class TransitionSystem:
|
cdef class TransitionSystem:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user