mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
* Pass a StateC pointer into the transition and validation methods in the parser, so that the GIL can be released over a batch of documents
This commit is contained in:
parent
daaad66448
commit
a47f00901b
|
@ -63,53 +63,53 @@ cdef cppclass StateC:
|
|||
free(this._stack - PADDING)
|
||||
free(this.shifted - PADDING)
|
||||
|
||||
int S(int i) nogil:
|
||||
int S(int i) nogil const:
|
||||
if i >= this._s_i:
|
||||
return -1
|
||||
return this._stack[this._s_i - (i+1)]
|
||||
|
||||
int B(int i) nogil:
|
||||
int B(int i) nogil const:
|
||||
if (i + this._b_i) >= this.length:
|
||||
return -1
|
||||
return this._buffer[this._b_i + i]
|
||||
|
||||
const TokenC* S_(int i) nogil:
|
||||
const TokenC* S_(int i) nogil const:
|
||||
return this.safe_get(this.S(i))
|
||||
|
||||
const TokenC* B_(int i) nogil:
|
||||
const TokenC* B_(int i) nogil const:
|
||||
return this.safe_get(this.B(i))
|
||||
|
||||
const TokenC* H_(int i) nogil:
|
||||
const TokenC* H_(int i) nogil const:
|
||||
return this.safe_get(this.H(i))
|
||||
|
||||
const TokenC* E_(int i) nogil:
|
||||
const TokenC* E_(int i) nogil const:
|
||||
return this.safe_get(this.E(i))
|
||||
|
||||
const TokenC* L_(int i, int idx) nogil:
|
||||
const TokenC* L_(int i, int idx) nogil const:
|
||||
return this.safe_get(this.L(i, idx))
|
||||
|
||||
const TokenC* R_(int i, int idx) nogil:
|
||||
const TokenC* R_(int i, int idx) nogil const:
|
||||
return this.safe_get(this.R(i, idx))
|
||||
|
||||
const TokenC* safe_get(int i) nogil:
|
||||
const TokenC* safe_get(int i) nogil const:
|
||||
if i < 0 or i >= this.length:
|
||||
return &this._empty_token
|
||||
else:
|
||||
return &this._sent[i]
|
||||
|
||||
int H(int i) nogil:
|
||||
int H(int i) nogil const:
|
||||
if i < 0 or i >= this.length:
|
||||
return -1
|
||||
return this._sent[i].head + i
|
||||
|
||||
int E(int i) nogil:
|
||||
int E(int i) nogil const:
|
||||
if this._e_i <= 0 or this._e_i >= this.length:
|
||||
return 0
|
||||
if i < 0 or i >= this._e_i:
|
||||
return 0
|
||||
return this._ents[this._e_i - (i+1)].start
|
||||
|
||||
int L(int i, int idx) nogil:
|
||||
int L(int i, int idx) nogil const:
|
||||
if idx < 1:
|
||||
return -1
|
||||
if i < 0 or i >= this.length:
|
||||
|
@ -135,7 +135,7 @@ cdef cppclass StateC:
|
|||
ptr += 1
|
||||
return -1
|
||||
|
||||
int R(int i, int idx) nogil:
|
||||
int R(int i, int idx) nogil const:
|
||||
if idx < 1:
|
||||
return -1
|
||||
if i < 0 or i >= this.length:
|
||||
|
@ -159,39 +159,39 @@ cdef cppclass StateC:
|
|||
ptr -= 1
|
||||
return -1
|
||||
|
||||
bint empty() nogil:
|
||||
bint empty() nogil const:
|
||||
return this._s_i <= 0
|
||||
|
||||
bint eol() nogil:
|
||||
bint eol() nogil const:
|
||||
return this.buffer_length() == 0
|
||||
|
||||
bint at_break() nogil:
|
||||
bint at_break() nogil const:
|
||||
return this._break != -1
|
||||
|
||||
bint is_final() nogil:
|
||||
bint is_final() nogil const:
|
||||
return this.stack_depth() <= 0 and this._b_i >= this.length
|
||||
|
||||
bint has_head(int i) nogil:
|
||||
bint has_head(int i) nogil const:
|
||||
return this.safe_get(i).head != 0
|
||||
|
||||
int n_L(int i) nogil:
|
||||
int n_L(int i) nogil const:
|
||||
return this.safe_get(i).l_kids
|
||||
|
||||
int n_R(int i) nogil:
|
||||
int n_R(int i) nogil const:
|
||||
return this.safe_get(i).r_kids
|
||||
|
||||
bint stack_is_connected() nogil:
|
||||
bint stack_is_connected() nogil const:
|
||||
return False
|
||||
|
||||
bint entity_is_open() nogil:
|
||||
bint entity_is_open() nogil const:
|
||||
if this._e_i < 1:
|
||||
return False
|
||||
return this._ents[this._e_i-1].end == -1
|
||||
|
||||
int stack_depth() nogil:
|
||||
int stack_depth() nogil const:
|
||||
return this._s_i
|
||||
|
||||
int buffer_length() nogil:
|
||||
int buffer_length() nogil const:
|
||||
if this._break != -1:
|
||||
return this._break - this._b_i
|
||||
else:
|
||||
|
|
|
@ -17,6 +17,7 @@ from libc.string cimport memcpy
|
|||
|
||||
from cymem.cymem cimport Pool
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
|
||||
|
||||
DEF NON_MONOTONIC = True
|
||||
|
@ -57,7 +58,7 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no
|
|||
cost += 1
|
||||
if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
|
||||
cost += 1
|
||||
cost += Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0
|
||||
cost += Break.is_valid(stcls.c, -1) and Break.move_cost(stcls, gold) == 0
|
||||
return cost
|
||||
|
||||
|
||||
|
@ -70,7 +71,7 @@ cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nog
|
|||
cost += gold.heads[target] == B_i
|
||||
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
|
||||
break
|
||||
if Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0:
|
||||
if Break.is_valid(stcls.c, -1) and Break.move_cost(stcls, gold) == 0:
|
||||
cost += 1
|
||||
return cost
|
||||
|
||||
|
@ -115,11 +116,11 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
|
|||
|
||||
cdef class Shift:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
return st.buffer_length() >= 2 and not st.c.shifted[st.B(0)] and not st.B_(0).sent_start
|
||||
cdef bint is_valid(const StateC* st, int label) nogil:
|
||||
return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and not st.B_(0).sent_start
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
cdef int transition(StateC* st, int label) nogil:
|
||||
st.push()
|
||||
st.fast_forward()
|
||||
|
||||
|
@ -138,11 +139,11 @@ cdef class Shift:
|
|||
|
||||
cdef class Reduce:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
cdef bint is_valid(const StateC* st, int label) nogil:
|
||||
return st.stack_depth() >= 2
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
cdef int transition(StateC* st, int label) nogil:
|
||||
if st.has_head(st.S(0)):
|
||||
st.pop()
|
||||
else:
|
||||
|
@ -164,11 +165,11 @@ cdef class Reduce:
|
|||
|
||||
cdef class LeftArc:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
cdef bint is_valid(const StateC* st, int label) nogil:
|
||||
return not st.B_(0).sent_start
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
cdef int transition(StateC* st, int label) nogil:
|
||||
st.add_arc(st.B(0), st.S(0), label)
|
||||
st.pop()
|
||||
st.fast_forward()
|
||||
|
@ -197,11 +198,11 @@ cdef class LeftArc:
|
|||
|
||||
cdef class RightArc:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
cdef bint is_valid(const StateC* st, int label) nogil:
|
||||
return not st.B_(0).sent_start
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
cdef int transition(StateC* st, int label) nogil:
|
||||
st.add_arc(st.S(0), st.B(0), label)
|
||||
st.push()
|
||||
st.fast_forward()
|
||||
|
@ -226,7 +227,7 @@ cdef class RightArc:
|
|||
|
||||
cdef class Break:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
cdef bint is_valid(const StateC* st, int label) nogil:
|
||||
cdef int i
|
||||
if not USE_BREAK:
|
||||
return False
|
||||
|
@ -243,7 +244,7 @@ cdef class Break:
|
|||
return True
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
cdef int transition(StateC* st, int label) nogil:
|
||||
st.set_break(st.B(0))
|
||||
st.fast_forward()
|
||||
|
||||
|
@ -396,11 +397,11 @@ cdef class ArcEager(TransitionSystem):
|
|||
|
||||
cdef int set_valid(self, int* output, StateClass stcls) nogil:
|
||||
cdef bint[N_MOVES] is_valid
|
||||
is_valid[SHIFT] = Shift.is_valid(stcls, -1)
|
||||
is_valid[REDUCE] = Reduce.is_valid(stcls, -1)
|
||||
is_valid[LEFT] = LeftArc.is_valid(stcls, -1)
|
||||
is_valid[RIGHT] = RightArc.is_valid(stcls, -1)
|
||||
is_valid[BREAK] = Break.is_valid(stcls, -1)
|
||||
is_valid[SHIFT] = Shift.is_valid(stcls.c, -1)
|
||||
is_valid[REDUCE] = Reduce.is_valid(stcls.c, -1)
|
||||
is_valid[LEFT] = LeftArc.is_valid(stcls.c, -1)
|
||||
is_valid[RIGHT] = RightArc.is_valid(stcls.c, -1)
|
||||
is_valid[BREAK] = Break.is_valid(stcls.c, -1)
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
output[i] = is_valid[self.c[i].move]
|
||||
|
@ -430,7 +431,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
|
||||
n_gold = 0
|
||||
for i in range(self.n_moves):
|
||||
if self.c[i].is_valid(stcls, self.c[i].label):
|
||||
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
||||
is_valid[i] = True
|
||||
move = self.c[i].move
|
||||
label = self.c[i].label
|
||||
|
|
|
@ -11,6 +11,7 @@ from ..gold cimport GoldParse
|
|||
from ..attrs cimport ENT_TYPE, ENT_IOB
|
||||
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
|
||||
|
||||
cdef enum:
|
||||
|
@ -150,11 +151,11 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
|
||||
cdef class Missing:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
cdef bint is_valid(const StateC* st, int label) nogil:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass s, int label) nogil:
|
||||
cdef int transition(StateC* s, int label) nogil:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
|
@ -164,7 +165,7 @@ cdef class Missing:
|
|||
|
||||
cdef class Begin:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
cdef bint is_valid(const StateC* st, int label) nogil:
|
||||
# Ensure we don't clobber preset entities. If no entity preset,
|
||||
# ent_iob is 0
|
||||
cdef int preset_ent_iob = st.B_(0).ent_iob
|
||||
|
@ -188,7 +189,7 @@ cdef class Begin:
|
|||
return label != 0 and not st.entity_is_open()
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
cdef int transition(StateC* st, int label) nogil:
|
||||
st.open_ent(label)
|
||||
st.set_ent_tag(st.B(0), 3, label)
|
||||
st.push()
|
||||
|
@ -214,7 +215,7 @@ cdef class Begin:
|
|||
|
||||
cdef class In:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
cdef bint is_valid(const StateC* st, int label) nogil:
|
||||
cdef int preset_ent_iob = st.B_(0).ent_iob
|
||||
if preset_ent_iob == 2:
|
||||
return False
|
||||
|
@ -230,7 +231,7 @@ cdef class In:
|
|||
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
cdef int transition(StateC* st, int label) nogil:
|
||||
st.set_ent_tag(st.B(0), 1, label)
|
||||
st.push()
|
||||
st.pop()
|
||||
|
@ -266,13 +267,13 @@ cdef class In:
|
|||
|
||||
cdef class Last:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
cdef bint is_valid(const StateC* st, int label) nogil:
|
||||
if st.B_(1).ent_iob == 1:
|
||||
return False
|
||||
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
cdef int transition(StateC* st, int label) nogil:
|
||||
st.close_ent()
|
||||
st.set_ent_tag(st.B(0), 1, label)
|
||||
st.push()
|
||||
|
@ -308,7 +309,7 @@ cdef class Last:
|
|||
|
||||
cdef class Unit:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
cdef bint is_valid(const StateC* st, int label) nogil:
|
||||
cdef int preset_ent_iob = st.B_(0).ent_iob
|
||||
if preset_ent_iob == 2:
|
||||
return False
|
||||
|
@ -321,7 +322,7 @@ cdef class Unit:
|
|||
return label != 0 and not st.entity_is_open()
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
cdef int transition(StateC* st, int label) nogil:
|
||||
st.open_ent(label)
|
||||
st.close_ent()
|
||||
st.set_ent_tag(st.B(0), 3, label)
|
||||
|
@ -348,7 +349,7 @@ cdef class Unit:
|
|||
|
||||
cdef class Out:
|
||||
@staticmethod
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
cdef bint is_valid(const StateC* st, int label) nogil:
|
||||
cdef int preset_ent_iob = st.B_(0).ent_iob
|
||||
if preset_ent_iob == 3:
|
||||
return False
|
||||
|
@ -357,7 +358,7 @@ cdef class Out:
|
|||
return not st.entity_is_open()
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
cdef int transition(StateC* st, int label) nogil:
|
||||
st.set_ent_tag(st.B(0), 2, 0)
|
||||
st.push()
|
||||
st.pop()
|
||||
|
|
|
@ -124,7 +124,7 @@ cdef class Parser:
|
|||
with gil:
|
||||
move_name = self.moves.move_name(action.move, action.label)
|
||||
raise ValueError("Illegal action: %s" % move_name)
|
||||
action.do(stcls, action.label)
|
||||
action.do(stcls.c, action.label)
|
||||
memset(eg.c.scores, 0, sizeof(eg.c.scores[0]) * eg.c.nr_class)
|
||||
memset(eg.c.costs, 0, sizeof(eg.c.costs[0]) * eg.c.nr_class)
|
||||
for i in range(eg.c.nr_class):
|
||||
|
@ -151,7 +151,7 @@ cdef class Parser:
|
|||
guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class)
|
||||
|
||||
action = self.moves.c[eg.guess]
|
||||
action.do(stcls, action.label)
|
||||
action.do(stcls.c, action.label)
|
||||
loss += eg.costs[eg.guess]
|
||||
eg.reset_classes(eg.nr_class)
|
||||
return loss
|
||||
|
@ -230,7 +230,7 @@ cdef class StepwiseState:
|
|||
action = self.parser.moves.c[clas]
|
||||
else:
|
||||
action = self.parser.moves.lookup_transition(action_name)
|
||||
action.do(self.stcls, action.label)
|
||||
action.do(self.stcls.c, action.label)
|
||||
|
||||
def finish(self):
|
||||
if self.stcls.is_final():
|
||||
|
|
|
@ -7,6 +7,7 @@ from ..gold cimport GoldParseC
|
|||
from ..strings cimport StringStore
|
||||
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
|
||||
|
||||
cdef struct Transition:
|
||||
|
@ -16,16 +17,16 @@ cdef struct Transition:
|
|||
|
||||
weight_t score
|
||||
|
||||
bint (*is_valid)(StateClass state, int label) nogil
|
||||
bint (*is_valid)(const StateC* state, int label) nogil
|
||||
weight_t (*get_cost)(StateClass state, const GoldParseC* gold, int label) nogil
|
||||
int (*do)(StateClass state, int label) nogil
|
||||
int (*do)(StateC* state, int label) nogil
|
||||
|
||||
|
||||
ctypedef weight_t (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
|
||||
ctypedef weight_t (*move_cost_func_t)(StateClass state, const GoldParseC* gold) nogil
|
||||
ctypedef weight_t (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
|
||||
|
||||
ctypedef int (*do_func_t)(StateClass state, int label) nogil
|
||||
ctypedef int (*do_func_t)(StateC* state, int label) nogil
|
||||
|
||||
|
||||
cdef class TransitionSystem:
|
||||
|
|
|
@ -64,12 +64,12 @@ cdef class TransitionSystem:
|
|||
|
||||
def is_valid(self, StateClass stcls, move_name):
|
||||
action = self.lookup_transition(move_name)
|
||||
return action.is_valid(stcls, action.label)
|
||||
return action.is_valid(stcls.c, action.label)
|
||||
|
||||
cdef int set_valid(self, int* is_valid, StateClass stcls) nogil:
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
is_valid[i] = self.c[i].is_valid(stcls, self.c[i].label)
|
||||
is_valid[i] = self.c[i].is_valid(stcls.c, self.c[i].label)
|
||||
|
||||
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
||||
StateClass stcls, GoldParse gold) except -1:
|
||||
|
|
Loading…
Reference in New Issue
Block a user