* Move StateClass into interface of transition functions

This commit is contained in:
Matthew Honnibal 2015-06-10 01:35:28 +02:00
parent 4b98b3e9c8
commit d68c686ec1
6 changed files with 86 additions and 76 deletions

View File

@ -120,11 +120,11 @@ cdef class Shift:
return not st.eol() return not st.eol()
@staticmethod @staticmethod
cdef int transition(State* state, int label) except -1: cdef int transition(StateClass state, int label) except -1:
# Set the dep label, in case we need it after we reduce # Set the dep label, in case we need it after we reduce
if NON_MONOTONIC: if NON_MONOTONIC:
state.sent[state.i].dep = label state._sent[state.B(0)].dep = label
push_stack(state) state.push()
@staticmethod @staticmethod
cdef int cost(StateClass st, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass st, const GoldParseC* gold, int label) except -1:
@ -148,10 +148,10 @@ cdef class Reduce:
return st.stack_depth() >= 2 and st.has_head(st.S(0)) return st.stack_depth() >= 2 and st.has_head(st.S(0))
@staticmethod @staticmethod
cdef int transition(State* state, int label) except -1: cdef int transition(StateClass st, int label) except -1:
if NON_MONOTONIC and not has_head(get_s0(state)) and state.stack_len >= 2: if NON_MONOTONIC and not st.has_head(st.S(0)) and st.stack_depth() >= 2:
add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep) st.add_arc(st.S(1), st.S(0), st.S_(0).dep)
pop_stack(state) st.pop()
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
@ -178,13 +178,13 @@ cdef class LeftArc:
return st.stack_depth() >= 1 and not st.has_head(st.S(0)) return st.stack_depth() >= 1 and not st.has_head(st.S(0))
@staticmethod @staticmethod
cdef int transition(State* state, int label) except -1: cdef int transition(StateClass st, int label) except -1:
# Interpret left-arcs from EOL as attachment to root # Interpret left-arcs from EOL as attachment to root
if at_eol(state): if st.eol():
add_dep(state, state.stack[0], state.stack[0], label) st.add_arc(st.S(0), st.S(0), label)
else: else:
add_dep(state, state.i, state.stack[0], label) st.add_arc(st.B(0), st.S(0), label)
pop_stack(state) st.pop()
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
@ -208,9 +208,9 @@ cdef class RightArc:
return st.stack_depth() >= 1 and not st.eol() return st.stack_depth() >= 1 and not st.eol()
@staticmethod @staticmethod
cdef int transition(State* state, int label) except -1: cdef int transition(StateClass st, int label) except -1:
add_dep(state, state.stack[0], state.i, label) st.add_arc(st.S(0), st.B(0), label)
push_stack(state) st.push()
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
@ -256,13 +256,12 @@ cdef class Break:
return True return True
@staticmethod @staticmethod
cdef int transition(State* state, int label) except -1: cdef int transition(StateClass st, int label) except -1:
state.sent[state.i-1].sent_end = True st.set_sent_end(st.B(0)-1)
while state.stack_len != 0: while not st.empty():
if get_s0(state).head == 0: if not st.has_head(st.S(0)):
get_s0(state).dep = label st._sent[st.S(0)].dep = label
state.stack -= 1 st.pop()
state.stack_len -= 1
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
@ -370,11 +369,11 @@ cdef class ArcEager(TransitionSystem):
cdef int initialize_state(self, State* state) except -1: cdef int initialize_state(self, State* state) except -1:
push_stack(state) push_stack(state)
cdef int finalize_state(self, State* state) except -1: cdef int finalize_state(self, StateClass st) except -1:
cdef int root_label = self.strings['ROOT'] cdef int root_label = self.strings['ROOT']
for i in range(state.sent_len): for i in range(st.length):
if state.sent[i].head == 0 and state.sent[i].dep == 0: if st._sent[i].head == 0 and st._sent[i].dep == 0:
state.sent[i].dep = root_label st._sent[i].dep = root_label
cdef int set_valid(self, bint* output, StateClass stcls) except -1: cdef int set_valid(self, bint* output, StateClass stcls) except -1:
cdef bint[N_MOVES] is_valid cdef bint[N_MOVES] is_valid

View File

@ -158,7 +158,7 @@ cdef class Missing:
return False return False
@staticmethod @staticmethod
cdef int transition(State* s, int label) except -1: cdef int transition(StateClass s, int label) except -1:
raise NotImplementedError raise NotImplementedError
@staticmethod @staticmethod
@ -172,15 +172,11 @@ cdef class Begin:
return label != 0 and not st.entity_is_open() return label != 0 and not st.entity_is_open()
@staticmethod @staticmethod
cdef int transition(State* s, int label) except -1: cdef int transition(StateClass st, int label) except -1:
s.ent += 1 st.open_ent(label)
s.ents_len += 1 st.set_ent_tag(st.B(0), 3, label)
s.ent.start = s.i st.push()
s.ent.label = label st.pop()
s.ent.end = 0
s.sent[s.i].ent_iob = 3
s.sent[s.i].ent_type = label
s.i += 1
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
@ -206,10 +202,10 @@ cdef class In:
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
@staticmethod @staticmethod
cdef int transition(State* s, int label) except -1: cdef int transition(StateClass st, int label) except -1:
s.sent[s.i].ent_iob = 1 st.set_ent_tag(st.B(0), 1, label)
s.sent[s.i].ent_type = label st.push()
s.i += 1 st.pop()
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
@ -246,11 +242,10 @@ cdef class Last:
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
@staticmethod @staticmethod
cdef int transition(State* s, int label) except -1: cdef int transition(StateClass st, int label) except -1:
s.ent.end = s.i+1 st.close_ent()
s.sent[s.i].ent_iob = 1 st.push()
s.sent[s.i].ent_type = label st.pop()
s.i += 1
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
@ -286,15 +281,12 @@ cdef class Unit:
return label != 0 and not st.entity_is_open() return label != 0 and not st.entity_is_open()
@staticmethod @staticmethod
cdef int transition(State* s, int label) except -1: cdef int transition(StateClass st, int label) except -1:
s.ent += 1 st.open_ent(label)
s.ents_len += 1 st.close_ent()
s.ent.start = s.i st.set_ent_tag(st.B(0), 3, label)
s.ent.label = label st.push()
s.ent.end = s.i+1 st.pop()
s.sent[s.i].ent_iob = 3
s.sent[s.i].ent_type = label
s.i += 1
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
@ -320,9 +312,10 @@ cdef class Out:
return not st.entity_is_open() return not st.entity_is_open()
@staticmethod @staticmethod
cdef int transition(State* s, int label) except -1: cdef int transition(StateClass st, int label) except -1:
s.sent[s.i].ent_iob = 2 st.set_ent_tag(st.B(0), 2, 0)
s.i += 1 st.push()
st.pop()
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:

View File

@ -106,15 +106,17 @@ cdef class Parser:
cdef State* state = new_state(mem, tokens.data, tokens.length) cdef State* state = new_state(mem, tokens.data, tokens.length)
self.moves.initialize_state(state) self.moves.initialize_state(state)
cdef StateClass stcls = StateClass(state.sent_len) cdef StateClass stcls = StateClass(state.sent_len)
stcls.from_struct(state)
cdef Transition guess cdef Transition guess
while not is_final(state): words = [w.orth_ for w in tokens]
stcls.from_struct(state) while not stcls.is_final():
#print stcls.print_state(words)
_new_fill_context(context, stcls) _new_fill_context(context, stcls)
scores = self.model.score(context) scores = self.model.score(context)
guess = self.moves.best_valid(scores, stcls) guess = self.moves.best_valid(scores, stcls)
guess.do(state, guess.label) guess.do(stcls, guess.label)
self.moves.finalize_state(state) self.moves.finalize_state(stcls)
tokens.set_parse(state.sent) tokens.set_parse(stcls._sent)
cdef int _beam_parse(self, Tokens tokens) except -1: cdef int _beam_parse(self, Tokens tokens) except -1:
cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width) cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width)
@ -123,8 +125,9 @@ cdef class Parser:
while not beam.is_done: while not beam.is_done:
self._advance_beam(beam, None, False) self._advance_beam(beam, None, False)
state = <State*>beam.at(0) state = <State*>beam.at(0)
self.moves.finalize_state(state) #self.moves.finalize_state(state)
tokens.set_parse(state.sent) #tokens.set_parse(state.sent)
raise Exception
def _greedy_train(self, Tokens tokens, GoldParse gold): def _greedy_train(self, Tokens tokens, GoldParse gold):
cdef Pool mem = Pool() cdef Pool mem = Pool()
@ -137,17 +140,18 @@ cdef class Parser:
cdef Transition guess cdef Transition guess
cdef Transition best cdef Transition best
cdef StateClass stcls = StateClass(state.sent_len) cdef StateClass stcls = StateClass(state.sent_len)
stcls.from_struct(state)
cdef atom_t[CONTEXT_SIZE] context cdef atom_t[CONTEXT_SIZE] context
loss = 0 loss = 0
while not is_final(state): words = [w.orth_ for w in tokens]
stcls.from_struct(state) while not stcls.is_final():
_new_fill_context(context, stcls) _new_fill_context(context, stcls)
scores = self.model.score(context) scores = self.model.score(context)
guess = self.moves.best_valid(scores, stcls) guess = self.moves.best_valid(scores, stcls)
best = self.moves.best_gold(scores, stcls, gold) best = self.moves.best_gold(scores, stcls, gold)
cost = guess.get_cost(stcls, &gold.c, guess.label) cost = guess.get_cost(stcls, &gold.c, guess.label)
self.model.update(context, guess.clas, best.clas, cost) self.model.update(context, guess.clas, best.clas, cost)
guess.do(state, guess.label) guess.do(stcls, guess.label)
loss += cost loss += cost
return loss return loss
@ -203,14 +207,16 @@ cdef class Parser:
cdef Pool mem = Pool() cdef Pool mem = Pool()
cdef State* state = new_state(mem, tokens.data, tokens.length) cdef State* state = new_state(mem, tokens.data, tokens.length)
self.moves.initialize_state(state) self.moves.initialize_state(state)
cdef StateClass stcls = StateClass(state.sent_len)
stcls.from_struct(state)
cdef class_t clas cdef class_t clas
cdef int n_feats cdef int n_feats
for clas in hist: for clas in hist:
fill_context(context, state) _new_fill_context(context, stcls)
feats = self.model._extractor.get_feats(context, &n_feats) feats = self.model._extractor.get_feats(context, &n_feats)
count_feats(counts[clas], feats, n_feats, inc) count_feats(counts[clas], feats, n_feats, inc)
self.moves.c[clas].do(state, self.moves.c[clas].label) self.moves.c[clas].do(stcls, self.moves.c[clas].label)
# These are passed as callbacks to thinc.search.Beam # These are passed as callbacks to thinc.search.Beam
@ -220,7 +226,8 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
src = <const State*>_src src = <const State*>_src
moves = <const Transition*>_moves moves = <const Transition*>_moves
copy_state(dest, src) copy_state(dest, src)
moves[clas].do(dest, moves[clas].label) raise Exception
#moves[clas].do(dest, moves[clas].label)
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:

View File

@ -126,7 +126,7 @@ cdef class StateClass:
return self._b_i >= self.length return self._b_i >= self.length
cdef bint is_final(self) nogil: cdef bint is_final(self) nogil:
return self.eol() and self.empty() return self.eol() and self.stack_depth() <= 1
cdef bint has_head(self, int i) nogil: cdef bint has_head(self, int i) nogil:
return self.safe_get(i).head != 0 return self.safe_get(i).head != 0
@ -196,7 +196,7 @@ cdef class StateClass:
self._sent[i].ent_type = ent_type self._sent[i].ent_type = ent_type
cdef void set_sent_end(self, int i) nogil: cdef void set_sent_end(self, int i) nogil:
if 0 < i < self.length: if 0 <= i < self.length:
self._sent[i].sent_end = True self._sent[i].sent_end = True
cdef void clone(self, StateClass src) nogil: cdef void clone(self, StateClass src) nogil:
@ -207,6 +207,17 @@ cdef class StateClass:
self._b_i = src._b_i self._b_i = src._b_i
self._s_i = src._s_i self._s_i = src._s_i
self._e_i = src._e_i self._e_i = src._e_i
def print_state(self, words):
words = list(words) + ['_']
top = words[self.S(0)] + '_%d' % self.H(self.S(0))
second = words[self.S(1)] + '_%d' % self.H(self.S(1))
third = words[self.S(2)] + '_%d' % self.H(self.S(2))
n0 = words[self.B(0)]
n1 = words[self.B(1)]
return ' '.join((str(self.stack_depth()), third, second, top, '|', n0, n1))
# From https://en.wikipedia.org/wiki/Hamming_weight # From https://en.wikipedia.org/wiki/Hamming_weight

View File

@ -19,14 +19,14 @@ cdef struct Transition:
bint (*is_valid)(StateClass state, int label) except -1 bint (*is_valid)(StateClass state, int label) except -1
int (*get_cost)(StateClass state, const GoldParseC* gold, int label) except -1 int (*get_cost)(StateClass state, const GoldParseC* gold, int label) except -1
int (*do)(State* state, int label) except -1 int (*do)(StateClass state, int label) except -1
ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1 ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) except -1 ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) except -1
ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1 ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
ctypedef int (*do_func_t)(State* state, int label) except -1 ctypedef int (*do_func_t)(StateClass state, int label) except -1
cdef class TransitionSystem: cdef class TransitionSystem:
@ -37,7 +37,7 @@ cdef class TransitionSystem:
cdef readonly int n_moves cdef readonly int n_moves
cdef int initialize_state(self, State* state) except -1 cdef int initialize_state(self, State* state) except -1
cdef int finalize_state(self, State* state) except -1 cdef int finalize_state(self, StateClass state) except -1
cdef int preprocess_gold(self, GoldParse gold) except -1 cdef int preprocess_gold(self, GoldParse gold) except -1

View File

@ -32,7 +32,7 @@ cdef class TransitionSystem:
cdef int initialize_state(self, State* state) except -1: cdef int initialize_state(self, State* state) except -1:
pass pass
cdef int finalize_state(self, State* state) except -1: cdef int finalize_state(self, StateClass state) except -1:
pass pass
cdef int preprocess_gold(self, GoldParse gold) except -1: cdef int preprocess_gold(self, GoldParse gold) except -1: