* Pass a StateC pointer into the transition and validation methods in the parser, so that the GIL can be released over a batch of documents

This commit is contained in:
Matthew Honnibal 2016-02-01 02:58:14 +01:00
parent daaad66448
commit a47f00901b
6 changed files with 66 additions and 63 deletions

View File

@ -63,53 +63,53 @@ cdef cppclass StateC:
free(this._stack - PADDING)
free(this.shifted - PADDING)
int S(int i) nogil:
int S(int i) nogil const:
if i >= this._s_i:
return -1
return this._stack[this._s_i - (i+1)]
int B(int i) nogil:
int B(int i) nogil const:
if (i + this._b_i) >= this.length:
return -1
return this._buffer[this._b_i + i]
const TokenC* S_(int i) nogil:
const TokenC* S_(int i) nogil const:
return this.safe_get(this.S(i))
const TokenC* B_(int i) nogil:
const TokenC* B_(int i) nogil const:
return this.safe_get(this.B(i))
const TokenC* H_(int i) nogil:
const TokenC* H_(int i) nogil const:
return this.safe_get(this.H(i))
const TokenC* E_(int i) nogil:
const TokenC* E_(int i) nogil const:
return this.safe_get(this.E(i))
const TokenC* L_(int i, int idx) nogil:
const TokenC* L_(int i, int idx) nogil const:
return this.safe_get(this.L(i, idx))
const TokenC* R_(int i, int idx) nogil:
const TokenC* R_(int i, int idx) nogil const:
return this.safe_get(this.R(i, idx))
const TokenC* safe_get(int i) nogil:
const TokenC* safe_get(int i) nogil const:
if i < 0 or i >= this.length:
return &this._empty_token
else:
return &this._sent[i]
int H(int i) nogil:
int H(int i) nogil const:
if i < 0 or i >= this.length:
return -1
return this._sent[i].head + i
int E(int i) nogil:
int E(int i) nogil const:
if this._e_i <= 0 or this._e_i >= this.length:
return 0
if i < 0 or i >= this._e_i:
return 0
return this._ents[this._e_i - (i+1)].start
int L(int i, int idx) nogil:
int L(int i, int idx) nogil const:
if idx < 1:
return -1
if i < 0 or i >= this.length:
@ -135,7 +135,7 @@ cdef cppclass StateC:
ptr += 1
return -1
int R(int i, int idx) nogil:
int R(int i, int idx) nogil const:
if idx < 1:
return -1
if i < 0 or i >= this.length:
@ -159,39 +159,39 @@ cdef cppclass StateC:
ptr -= 1
return -1
bint empty() nogil:
bint empty() nogil const:
return this._s_i <= 0
bint eol() nogil:
bint eol() nogil const:
return this.buffer_length() == 0
bint at_break() nogil:
bint at_break() nogil const:
return this._break != -1
bint is_final() nogil:
bint is_final() nogil const:
return this.stack_depth() <= 0 and this._b_i >= this.length
bint has_head(int i) nogil:
bint has_head(int i) nogil const:
return this.safe_get(i).head != 0
int n_L(int i) nogil:
int n_L(int i) nogil const:
return this.safe_get(i).l_kids
int n_R(int i) nogil:
int n_R(int i) nogil const:
return this.safe_get(i).r_kids
bint stack_is_connected() nogil:
bint stack_is_connected() nogil const:
return False
bint entity_is_open() nogil:
bint entity_is_open() nogil const:
if this._e_i < 1:
return False
return this._ents[this._e_i-1].end == -1
int stack_depth() nogil:
int stack_depth() nogil const:
return this._s_i
int buffer_length() nogil:
int buffer_length() nogil const:
if this._break != -1:
return this._break - this._b_i
else:

View File

@ -17,6 +17,7 @@ from libc.string cimport memcpy
from cymem.cymem cimport Pool
from .stateclass cimport StateClass
from ._state cimport StateC
DEF NON_MONOTONIC = True
@ -57,7 +58,7 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no
cost += 1
if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
cost += 1
cost += Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0
cost += Break.is_valid(stcls.c, -1) and Break.move_cost(stcls, gold) == 0
return cost
@ -70,7 +71,7 @@ cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nog
cost += gold.heads[target] == B_i
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
break
if Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0:
if Break.is_valid(stcls.c, -1) and Break.move_cost(stcls, gold) == 0:
cost += 1
return cost
@ -115,11 +116,11 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
cdef class Shift:
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
return st.buffer_length() >= 2 and not st.c.shifted[st.B(0)] and not st.B_(0).sent_start
cdef bint is_valid(const StateC* st, int label) nogil:
return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and not st.B_(0).sent_start
@staticmethod
cdef int transition(StateClass st, int label) nogil:
cdef int transition(StateC* st, int label) nogil:
st.push()
st.fast_forward()
@ -138,11 +139,11 @@ cdef class Shift:
cdef class Reduce:
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
cdef bint is_valid(const StateC* st, int label) nogil:
return st.stack_depth() >= 2
@staticmethod
cdef int transition(StateClass st, int label) nogil:
cdef int transition(StateC* st, int label) nogil:
if st.has_head(st.S(0)):
st.pop()
else:
@ -164,11 +165,11 @@ cdef class Reduce:
cdef class LeftArc:
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
cdef bint is_valid(const StateC* st, int label) nogil:
return not st.B_(0).sent_start
@staticmethod
cdef int transition(StateClass st, int label) nogil:
cdef int transition(StateC* st, int label) nogil:
st.add_arc(st.B(0), st.S(0), label)
st.pop()
st.fast_forward()
@ -197,11 +198,11 @@ cdef class LeftArc:
cdef class RightArc:
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
cdef bint is_valid(const StateC* st, int label) nogil:
return not st.B_(0).sent_start
@staticmethod
cdef int transition(StateClass st, int label) nogil:
cdef int transition(StateC* st, int label) nogil:
st.add_arc(st.S(0), st.B(0), label)
st.push()
st.fast_forward()
@ -226,7 +227,7 @@ cdef class RightArc:
cdef class Break:
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
cdef bint is_valid(const StateC* st, int label) nogil:
cdef int i
if not USE_BREAK:
return False
@ -243,7 +244,7 @@ cdef class Break:
return True
@staticmethod
cdef int transition(StateClass st, int label) nogil:
cdef int transition(StateC* st, int label) nogil:
st.set_break(st.B(0))
st.fast_forward()
@ -396,11 +397,11 @@ cdef class ArcEager(TransitionSystem):
cdef int set_valid(self, int* output, StateClass stcls) nogil:
cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = Shift.is_valid(stcls, -1)
is_valid[REDUCE] = Reduce.is_valid(stcls, -1)
is_valid[LEFT] = LeftArc.is_valid(stcls, -1)
is_valid[RIGHT] = RightArc.is_valid(stcls, -1)
is_valid[BREAK] = Break.is_valid(stcls, -1)
is_valid[SHIFT] = Shift.is_valid(stcls.c, -1)
is_valid[REDUCE] = Reduce.is_valid(stcls.c, -1)
is_valid[LEFT] = LeftArc.is_valid(stcls.c, -1)
is_valid[RIGHT] = RightArc.is_valid(stcls.c, -1)
is_valid[BREAK] = Break.is_valid(stcls.c, -1)
cdef int i
for i in range(self.n_moves):
output[i] = is_valid[self.c[i].move]
@ -430,7 +431,7 @@ cdef class ArcEager(TransitionSystem):
n_gold = 0
for i in range(self.n_moves):
if self.c[i].is_valid(stcls, self.c[i].label):
if self.c[i].is_valid(stcls.c, self.c[i].label):
is_valid[i] = True
move = self.c[i].move
label = self.c[i].label

View File

@ -11,6 +11,7 @@ from ..gold cimport GoldParse
from ..attrs cimport ENT_TYPE, ENT_IOB
from .stateclass cimport StateClass
from ._state cimport StateC
cdef enum:
@ -150,11 +151,11 @@ cdef class BiluoPushDown(TransitionSystem):
cdef class Missing:
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
cdef bint is_valid(const StateC* st, int label) nogil:
return False
@staticmethod
cdef int transition(StateClass s, int label) nogil:
cdef int transition(StateC* s, int label) nogil:
pass
@staticmethod
@ -164,7 +165,7 @@ cdef class Missing:
cdef class Begin:
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
cdef bint is_valid(const StateC* st, int label) nogil:
# Ensure we don't clobber preset entities. If no entity preset,
# ent_iob is 0
cdef int preset_ent_iob = st.B_(0).ent_iob
@ -188,7 +189,7 @@ cdef class Begin:
return label != 0 and not st.entity_is_open()
@staticmethod
cdef int transition(StateClass st, int label) nogil:
cdef int transition(StateC* st, int label) nogil:
st.open_ent(label)
st.set_ent_tag(st.B(0), 3, label)
st.push()
@ -214,7 +215,7 @@ cdef class Begin:
cdef class In:
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
cdef bint is_valid(const StateC* st, int label) nogil:
cdef int preset_ent_iob = st.B_(0).ent_iob
if preset_ent_iob == 2:
return False
@ -230,7 +231,7 @@ cdef class In:
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
@staticmethod
cdef int transition(StateClass st, int label) nogil:
cdef int transition(StateC* st, int label) nogil:
st.set_ent_tag(st.B(0), 1, label)
st.push()
st.pop()
@ -266,13 +267,13 @@ cdef class In:
cdef class Last:
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
cdef bint is_valid(const StateC* st, int label) nogil:
if st.B_(1).ent_iob == 1:
return False
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
@staticmethod
cdef int transition(StateClass st, int label) nogil:
cdef int transition(StateC* st, int label) nogil:
st.close_ent()
st.set_ent_tag(st.B(0), 1, label)
st.push()
@ -308,7 +309,7 @@ cdef class Last:
cdef class Unit:
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
cdef bint is_valid(const StateC* st, int label) nogil:
cdef int preset_ent_iob = st.B_(0).ent_iob
if preset_ent_iob == 2:
return False
@ -321,7 +322,7 @@ cdef class Unit:
return label != 0 and not st.entity_is_open()
@staticmethod
cdef int transition(StateClass st, int label) nogil:
cdef int transition(StateC* st, int label) nogil:
st.open_ent(label)
st.close_ent()
st.set_ent_tag(st.B(0), 3, label)
@ -348,7 +349,7 @@ cdef class Unit:
cdef class Out:
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
cdef bint is_valid(const StateC* st, int label) nogil:
cdef int preset_ent_iob = st.B_(0).ent_iob
if preset_ent_iob == 3:
return False
@ -357,7 +358,7 @@ cdef class Out:
return not st.entity_is_open()
@staticmethod
cdef int transition(StateClass st, int label) nogil:
cdef int transition(StateC* st, int label) nogil:
st.set_ent_tag(st.B(0), 2, 0)
st.push()
st.pop()

View File

@ -124,7 +124,7 @@ cdef class Parser:
with gil:
move_name = self.moves.move_name(action.move, action.label)
raise ValueError("Illegal action: %s" % move_name)
action.do(stcls, action.label)
action.do(stcls.c, action.label)
memset(eg.c.scores, 0, sizeof(eg.c.scores[0]) * eg.c.nr_class)
memset(eg.c.costs, 0, sizeof(eg.c.costs[0]) * eg.c.nr_class)
for i in range(eg.c.nr_class):
@ -151,7 +151,7 @@ cdef class Parser:
guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class)
action = self.moves.c[eg.guess]
action.do(stcls, action.label)
action.do(stcls.c, action.label)
loss += eg.costs[eg.guess]
eg.reset_classes(eg.nr_class)
return loss
@ -230,7 +230,7 @@ cdef class StepwiseState:
action = self.parser.moves.c[clas]
else:
action = self.parser.moves.lookup_transition(action_name)
action.do(self.stcls, action.label)
action.do(self.stcls.c, action.label)
def finish(self):
if self.stcls.is_final():

View File

@ -7,6 +7,7 @@ from ..gold cimport GoldParseC
from ..strings cimport StringStore
from .stateclass cimport StateClass
from ._state cimport StateC
cdef struct Transition:
@ -16,16 +17,16 @@ cdef struct Transition:
weight_t score
bint (*is_valid)(StateClass state, int label) nogil
bint (*is_valid)(const StateC* state, int label) nogil
weight_t (*get_cost)(StateClass state, const GoldParseC* gold, int label) nogil
int (*do)(StateClass state, int label) nogil
int (*do)(StateC* state, int label) nogil
ctypedef weight_t (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
ctypedef weight_t (*move_cost_func_t)(StateClass state, const GoldParseC* gold) nogil
ctypedef weight_t (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
ctypedef int (*do_func_t)(StateClass state, int label) nogil
ctypedef int (*do_func_t)(StateC* state, int label) nogil
cdef class TransitionSystem:

View File

@ -64,12 +64,12 @@ cdef class TransitionSystem:
def is_valid(self, StateClass stcls, move_name):
action = self.lookup_transition(move_name)
return action.is_valid(stcls, action.label)
return action.is_valid(stcls.c, action.label)
cdef int set_valid(self, int* is_valid, StateClass stcls) nogil:
cdef int i
for i in range(self.n_moves):
is_valid[i] = self.c[i].is_valid(stcls, self.c[i].label)
is_valid[i] = self.c[i].is_valid(stcls.c, self.c[i].label)
cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass stcls, GoldParse gold) except -1: