* Remove gil from parser.call

This commit is contained in:
Matthew Honnibal 2015-07-14 23:47:03 +02:00
parent 3c1e3e9ee8
commit 9a8db9743c
8 changed files with 36 additions and 33 deletions

View File

@ -24,7 +24,7 @@ cdef class Model:
cdef readonly int n_feats cdef readonly int n_feats
cdef const weight_t* score(self, atom_t* context) except NULL cdef const weight_t* score(self, atom_t* context) except NULL
cdef int set_scores(self, weight_t* scores, atom_t* context) except -1 cdef int set_scores(self, weight_t* scores, atom_t* context) nogil
cdef int update(self, atom_t* context, class_t guess, class_t gold, int cost) except -1 cdef int update(self, atom_t* context, class_t guess, class_t gold, int cost) except -1

View File

@ -75,7 +75,7 @@ cdef class Model:
feats = self._extractor.get_feats(context, &n_feats) feats = self._extractor.get_feats(context, &n_feats)
return self._model.get_scores(feats, n_feats) return self._model.get_scores(feats, n_feats)
cdef int set_scores(self, weight_t* scores, atom_t* context) except -1: cdef int set_scores(self, weight_t* scores, atom_t* context) nogil:
cdef int n_feats cdef int n_feats
feats = self._extractor.get_feats(context, &n_feats) feats = self._extractor.get_feats(context, &n_feats)
self._model.set_scores(scores, feats, n_feats) self._model.set_scores(scores, feats, n_feats)

View File

@ -374,17 +374,16 @@ cdef class ArcEager(TransitionSystem):
st._sent[i].r_edge = i st._sent[i].r_edge = i
st.fast_forward() st.fast_forward()
cdef int finalize_state(self, StateClass st) except -1: cdef int finalize_state(self, StateClass st) nogil:
cdef int root_label = self.strings['ROOT']
for i in range(st.length): for i in range(st.length):
if st._sent[i].head == 0 and st._sent[i].dep == 0: if st._sent[i].head == 0 and st._sent[i].dep == 0:
st._sent[i].dep = root_label st._sent[i].dep = self.root_label
# If we're not using the Break transition, we segment via root-labelled # If we're not using the Break transition, we segment via root-labelled
# arcs between the root words. # arcs between the root words.
elif USE_ROOT_ARC_SEGMENT and st._sent[i].dep == root_label: elif USE_ROOT_ARC_SEGMENT and st._sent[i].dep == self.root_label:
st._sent[i].head = 0 st._sent[i].head = 0
cdef int set_valid(self, bint* output, StateClass stcls) except -1: cdef int set_valid(self, bint* output, StateClass stcls) nogil:
cdef bint[N_MOVES] is_valid cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = Shift.is_valid(stcls, -1) is_valid[SHIFT] = Shift.is_valid(stcls, -1)
is_valid[REDUCE] = Reduce.is_valid(stcls, -1) is_valid[REDUCE] = Reduce.is_valid(stcls, -1)
@ -392,11 +391,8 @@ cdef class ArcEager(TransitionSystem):
is_valid[RIGHT] = RightArc.is_valid(stcls, -1) is_valid[RIGHT] = RightArc.is_valid(stcls, -1)
is_valid[BREAK] = Break.is_valid(stcls, -1) is_valid[BREAK] = Break.is_valid(stcls, -1)
cdef int i cdef int i
n_valid = 0
for i in range(self.n_moves): for i in range(self.n_moves):
output[i] = is_valid[self.c[i].move] output[i] = is_valid[self.c[i].move]
n_valid += output[i]
assert n_valid >= 1
cdef int set_costs(self, bint* is_valid, int* costs, cdef int set_costs(self, bint* is_valid, int* costs,
StateClass stcls, GoldParse gold) except -1: StateClass stcls, GoldParse gold) except -1:

View File

@ -6,9 +6,13 @@ from .arc_eager cimport TransitionSystem
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
from ..structs cimport TokenC from ..structs cimport TokenC
from thinc.api cimport Example, ExampleC
from .stateclass cimport StateClass
cdef class Parser: cdef class Parser:
cdef readonly object cfg cdef readonly object cfg
cdef readonly Model model cdef readonly Model model
cdef readonly TransitionSystem moves cdef readonly TransitionSystem moves
cdef void parse(self, StateClass stcls, ExampleC eg) nogil

View File

@ -22,7 +22,7 @@ from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
from util import Config from util import Config
from thinc.api cimport Example from thinc.api cimport Example, ExampleC
from ..structs cimport TokenC from ..structs cimport TokenC
@ -41,6 +41,8 @@ from ._parse_features cimport CONTEXT_SIZE
from ._parse_features cimport fill_context from ._parse_features cimport fill_context
from .stateclass cimport StateClass from .stateclass cimport StateClass
from .._ml cimport arg_max_if_true
DEBUG = False DEBUG = False
def set_debug(val): def set_debug(val):
@ -65,6 +67,9 @@ def ParserFactory(transition_system):
return lambda strings, dir_: Parser(strings, dir_, transition_system) return lambda strings, dir_: Parser(strings, dir_, transition_system)
DEF stack_class_alloc_limit = 256
cdef class Parser: cdef class Parser:
def __init__(self, StringStore strings, model_dir, transition_system): def __init__(self, StringStore strings, model_dir, transition_system):
if not os.path.exists(model_dir): if not os.path.exists(model_dir):
@ -83,18 +88,22 @@ cdef class Parser:
cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE, cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
self.model.n_feats, self.model.n_feats) self.model.n_feats, self.model.n_feats)
while not stcls.is_final(): with nogil:
memset(eg.c.scores, 0, eg.c.nr_class * sizeof(weight_t)) self.parse(stcls, eg.c)
self.moves.set_valid(<bint*>eg.c.is_valid, stcls)
fill_context(eg.c.atoms, stcls)
self.model.predict(eg)
self.moves.c[eg.c.guess].do(stcls, self.moves.c[eg.c.guess].label)
self.moves.finalize_state(stcls)
tokens.set_parse(stcls._sent) tokens.set_parse(stcls._sent)
cdef void parse(self, StateClass stcls, ExampleC eg) nogil:
while not stcls.is_final():
memset(eg.scores, 0, eg.nr_class * sizeof(weight_t))
self.moves.set_valid(<bint*>eg.is_valid, stcls)
fill_context(eg.atoms, stcls)
self.model.set_scores(eg.scores, eg.atoms)
eg.guess = arg_max_if_true(eg.scores, eg.is_valid, self.model.n_classes)
self.moves.c[eg.guess].do(stcls, self.moves.c[eg.guess].label)
self.moves.finalize_state(stcls)
def train(self, Doc tokens, GoldParse gold): def train(self, Doc tokens, GoldParse gold):
self.moves.preprocess_gold(gold) self.moves.preprocess_gold(gold)
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length) cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)

View File

@ -52,11 +52,7 @@ cdef class StateClass:
cdef const TokenC* target = &self._sent[i] cdef const TokenC* target = &self._sent[i]
if target.l_kids < idx: if target.l_kids < idx:
return -1 return -1
<<<<<<< HEAD
cdef const TokenC* ptr = &self._sent[target.l_edge] cdef const TokenC* ptr = &self._sent[target.l_edge]
=======
cdef const TokenC* ptr = self._sent
>>>>>>> neuralnet
while ptr < target: while ptr < target:
# If this head is still to the right of us, we can skip to it # If this head is still to the right of us, we can skip to it
@ -82,11 +78,7 @@ cdef class StateClass:
cdef const TokenC* target = &self._sent[i] cdef const TokenC* target = &self._sent[i]
if target.r_kids < idx: if target.r_kids < idx:
return -1 return -1
<<<<<<< HEAD
cdef const TokenC* ptr = &self._sent[target.r_edge] cdef const TokenC* ptr = &self._sent[target.r_edge]
=======
cdef const TokenC* ptr = self._sent + (self.length - 1)
>>>>>>> neuralnet
while ptr > target: while ptr > target:
# If this head is still to the right of us, we can skip to it # If this head is still to the right of us, we can skip to it
# No token that's between this token and this head could be our # No token that's between this token and this head could be our

View File

@ -34,9 +34,10 @@ cdef class TransitionSystem:
cdef const Transition* c cdef const Transition* c
cdef bint* _is_valid cdef bint* _is_valid
cdef readonly int n_moves cdef readonly int n_moves
cdef public int root_label
cdef int initialize_state(self, StateClass state) except -1 cdef int initialize_state(self, StateClass state) except -1
cdef int finalize_state(self, StateClass state) except -1 cdef int finalize_state(self, StateClass state) nogil
cdef int preprocess_gold(self, GoldParse gold) except -1 cdef int preprocess_gold(self, GoldParse gold) except -1
@ -44,7 +45,7 @@ cdef class TransitionSystem:
cdef Transition init_transition(self, int clas, int move, int label) except * cdef Transition init_transition(self, int clas, int move, int label) except *
cdef int set_valid(self, bint* output, StateClass state) except -1 cdef int set_valid(self, bint* output, StateClass state) nogil
cdef int set_costs(self, bint* is_valid, int* costs, cdef int set_costs(self, bint* is_valid, int* costs,
StateClass state, GoldParse gold) except -1 StateClass state, GoldParse gold) except -1

View File

@ -27,11 +27,12 @@ cdef class TransitionSystem:
moves[i] = self.init_transition(i, int(action), label_id) moves[i] = self.init_transition(i, int(action), label_id)
i += 1 i += 1
self.c = moves self.c = moves
self.root_label = self.strings['ROOT']
cdef int initialize_state(self, StateClass state) except -1: cdef int initialize_state(self, StateClass state) except -1:
pass pass
cdef int finalize_state(self, StateClass state) except -1: cdef int finalize_state(self, StateClass state) nogil:
pass pass
cdef int preprocess_gold(self, GoldParse gold) except -1: cdef int preprocess_gold(self, GoldParse gold) except -1:
@ -43,7 +44,7 @@ cdef class TransitionSystem:
cdef Transition init_transition(self, int clas, int move, int label) except *: cdef Transition init_transition(self, int clas, int move, int label) except *:
raise NotImplementedError raise NotImplementedError
cdef int set_valid(self, bint* is_valid, StateClass stcls) except -1: cdef int set_valid(self, bint* is_valid, StateClass stcls) nogil:
cdef int i cdef int i
for i in range(self.n_moves): for i in range(self.n_moves):
is_valid[i] = self.c[i].is_valid(stcls, self.c[i].label) is_valid[i] = self.c[i].is_valid(stcls, self.c[i].label)