mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-27 20:33:42 +03:00
* Move StateClass into interface of transition functions
This commit is contained in:
parent
4b98b3e9c8
commit
d68c686ec1
|
@ -120,11 +120,11 @@ cdef class Shift:
|
||||||
return not st.eol()
|
return not st.eol()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(State* state, int label) except -1:
|
cdef int transition(StateClass state, int label) except -1:
|
||||||
# Set the dep label, in case we need it after we reduce
|
# Set the dep label, in case we need it after we reduce
|
||||||
if NON_MONOTONIC:
|
if NON_MONOTONIC:
|
||||||
state.sent[state.i].dep = label
|
state._sent[state.B(0)].dep = label
|
||||||
push_stack(state)
|
state.push()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(StateClass st, const GoldParseC* gold, int label) except -1:
|
cdef int cost(StateClass st, const GoldParseC* gold, int label) except -1:
|
||||||
|
@ -148,10 +148,10 @@ cdef class Reduce:
|
||||||
return st.stack_depth() >= 2 and st.has_head(st.S(0))
|
return st.stack_depth() >= 2 and st.has_head(st.S(0))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(State* state, int label) except -1:
|
cdef int transition(StateClass st, int label) except -1:
|
||||||
if NON_MONOTONIC and not has_head(get_s0(state)) and state.stack_len >= 2:
|
if NON_MONOTONIC and not st.has_head(st.S(0)) and st.stack_depth() >= 2:
|
||||||
add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep)
|
st.add_arc(st.S(1), st.S(0), st.S_(0).dep)
|
||||||
pop_stack(state)
|
st.pop()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||||
|
@ -178,13 +178,13 @@ cdef class LeftArc:
|
||||||
return st.stack_depth() >= 1 and not st.has_head(st.S(0))
|
return st.stack_depth() >= 1 and not st.has_head(st.S(0))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(State* state, int label) except -1:
|
cdef int transition(StateClass st, int label) except -1:
|
||||||
# Interpret left-arcs from EOL as attachment to root
|
# Interpret left-arcs from EOL as attachment to root
|
||||||
if at_eol(state):
|
if st.eol():
|
||||||
add_dep(state, state.stack[0], state.stack[0], label)
|
st.add_arc(st.S(0), st.S(0), label)
|
||||||
else:
|
else:
|
||||||
add_dep(state, state.i, state.stack[0], label)
|
st.add_arc(st.B(0), st.S(0), label)
|
||||||
pop_stack(state)
|
st.pop()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||||
|
@ -208,9 +208,9 @@ cdef class RightArc:
|
||||||
return st.stack_depth() >= 1 and not st.eol()
|
return st.stack_depth() >= 1 and not st.eol()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(State* state, int label) except -1:
|
cdef int transition(StateClass st, int label) except -1:
|
||||||
add_dep(state, state.stack[0], state.i, label)
|
st.add_arc(st.S(0), st.B(0), label)
|
||||||
push_stack(state)
|
st.push()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||||
|
@ -256,13 +256,12 @@ cdef class Break:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(State* state, int label) except -1:
|
cdef int transition(StateClass st, int label) except -1:
|
||||||
state.sent[state.i-1].sent_end = True
|
st.set_sent_end(st.B(0)-1)
|
||||||
while state.stack_len != 0:
|
while not st.empty():
|
||||||
if get_s0(state).head == 0:
|
if not st.has_head(st.S(0)):
|
||||||
get_s0(state).dep = label
|
st._sent[st.S(0)].dep = label
|
||||||
state.stack -= 1
|
st.pop()
|
||||||
state.stack_len -= 1
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||||
|
@ -370,11 +369,11 @@ cdef class ArcEager(TransitionSystem):
|
||||||
cdef int initialize_state(self, State* state) except -1:
|
cdef int initialize_state(self, State* state) except -1:
|
||||||
push_stack(state)
|
push_stack(state)
|
||||||
|
|
||||||
cdef int finalize_state(self, State* state) except -1:
|
cdef int finalize_state(self, StateClass st) except -1:
|
||||||
cdef int root_label = self.strings['ROOT']
|
cdef int root_label = self.strings['ROOT']
|
||||||
for i in range(state.sent_len):
|
for i in range(st.length):
|
||||||
if state.sent[i].head == 0 and state.sent[i].dep == 0:
|
if st._sent[i].head == 0 and st._sent[i].dep == 0:
|
||||||
state.sent[i].dep = root_label
|
st._sent[i].dep = root_label
|
||||||
|
|
||||||
cdef int set_valid(self, bint* output, StateClass stcls) except -1:
|
cdef int set_valid(self, bint* output, StateClass stcls) except -1:
|
||||||
cdef bint[N_MOVES] is_valid
|
cdef bint[N_MOVES] is_valid
|
||||||
|
|
|
@ -158,7 +158,7 @@ cdef class Missing:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(State* s, int label) except -1:
|
cdef int transition(StateClass s, int label) except -1:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -172,15 +172,11 @@ cdef class Begin:
|
||||||
return label != 0 and not st.entity_is_open()
|
return label != 0 and not st.entity_is_open()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(State* s, int label) except -1:
|
cdef int transition(StateClass st, int label) except -1:
|
||||||
s.ent += 1
|
st.open_ent(label)
|
||||||
s.ents_len += 1
|
st.set_ent_tag(st.B(0), 3, label)
|
||||||
s.ent.start = s.i
|
st.push()
|
||||||
s.ent.label = label
|
st.pop()
|
||||||
s.ent.end = 0
|
|
||||||
s.sent[s.i].ent_iob = 3
|
|
||||||
s.sent[s.i].ent_type = label
|
|
||||||
s.i += 1
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||||
|
@ -206,10 +202,10 @@ cdef class In:
|
||||||
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
|
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(State* s, int label) except -1:
|
cdef int transition(StateClass st, int label) except -1:
|
||||||
s.sent[s.i].ent_iob = 1
|
st.set_ent_tag(st.B(0), 1, label)
|
||||||
s.sent[s.i].ent_type = label
|
st.push()
|
||||||
s.i += 1
|
st.pop()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||||
|
@ -246,11 +242,10 @@ cdef class Last:
|
||||||
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
|
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(State* s, int label) except -1:
|
cdef int transition(StateClass st, int label) except -1:
|
||||||
s.ent.end = s.i+1
|
st.close_ent()
|
||||||
s.sent[s.i].ent_iob = 1
|
st.push()
|
||||||
s.sent[s.i].ent_type = label
|
st.pop()
|
||||||
s.i += 1
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||||
|
@ -286,15 +281,12 @@ cdef class Unit:
|
||||||
return label != 0 and not st.entity_is_open()
|
return label != 0 and not st.entity_is_open()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(State* s, int label) except -1:
|
cdef int transition(StateClass st, int label) except -1:
|
||||||
s.ent += 1
|
st.open_ent(label)
|
||||||
s.ents_len += 1
|
st.close_ent()
|
||||||
s.ent.start = s.i
|
st.set_ent_tag(st.B(0), 3, label)
|
||||||
s.ent.label = label
|
st.push()
|
||||||
s.ent.end = s.i+1
|
st.pop()
|
||||||
s.sent[s.i].ent_iob = 3
|
|
||||||
s.sent[s.i].ent_type = label
|
|
||||||
s.i += 1
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||||
|
@ -320,9 +312,10 @@ cdef class Out:
|
||||||
return not st.entity_is_open()
|
return not st.entity_is_open()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(State* s, int label) except -1:
|
cdef int transition(StateClass st, int label) except -1:
|
||||||
s.sent[s.i].ent_iob = 2
|
st.set_ent_tag(st.B(0), 2, 0)
|
||||||
s.i += 1
|
st.push()
|
||||||
|
st.pop()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||||
|
|
|
@ -106,15 +106,17 @@ cdef class Parser:
|
||||||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
||||||
self.moves.initialize_state(state)
|
self.moves.initialize_state(state)
|
||||||
cdef StateClass stcls = StateClass(state.sent_len)
|
cdef StateClass stcls = StateClass(state.sent_len)
|
||||||
cdef Transition guess
|
|
||||||
while not is_final(state):
|
|
||||||
stcls.from_struct(state)
|
stcls.from_struct(state)
|
||||||
|
cdef Transition guess
|
||||||
|
words = [w.orth_ for w in tokens]
|
||||||
|
while not stcls.is_final():
|
||||||
|
#print stcls.print_state(words)
|
||||||
_new_fill_context(context, stcls)
|
_new_fill_context(context, stcls)
|
||||||
scores = self.model.score(context)
|
scores = self.model.score(context)
|
||||||
guess = self.moves.best_valid(scores, stcls)
|
guess = self.moves.best_valid(scores, stcls)
|
||||||
guess.do(state, guess.label)
|
guess.do(stcls, guess.label)
|
||||||
self.moves.finalize_state(state)
|
self.moves.finalize_state(stcls)
|
||||||
tokens.set_parse(state.sent)
|
tokens.set_parse(stcls._sent)
|
||||||
|
|
||||||
cdef int _beam_parse(self, Tokens tokens) except -1:
|
cdef int _beam_parse(self, Tokens tokens) except -1:
|
||||||
cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width)
|
cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width)
|
||||||
|
@ -123,8 +125,9 @@ cdef class Parser:
|
||||||
while not beam.is_done:
|
while not beam.is_done:
|
||||||
self._advance_beam(beam, None, False)
|
self._advance_beam(beam, None, False)
|
||||||
state = <State*>beam.at(0)
|
state = <State*>beam.at(0)
|
||||||
self.moves.finalize_state(state)
|
#self.moves.finalize_state(state)
|
||||||
tokens.set_parse(state.sent)
|
#tokens.set_parse(state.sent)
|
||||||
|
raise Exception
|
||||||
|
|
||||||
def _greedy_train(self, Tokens tokens, GoldParse gold):
|
def _greedy_train(self, Tokens tokens, GoldParse gold):
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
|
@ -137,17 +140,18 @@ cdef class Parser:
|
||||||
cdef Transition guess
|
cdef Transition guess
|
||||||
cdef Transition best
|
cdef Transition best
|
||||||
cdef StateClass stcls = StateClass(state.sent_len)
|
cdef StateClass stcls = StateClass(state.sent_len)
|
||||||
|
stcls.from_struct(state)
|
||||||
cdef atom_t[CONTEXT_SIZE] context
|
cdef atom_t[CONTEXT_SIZE] context
|
||||||
loss = 0
|
loss = 0
|
||||||
while not is_final(state):
|
words = [w.orth_ for w in tokens]
|
||||||
stcls.from_struct(state)
|
while not stcls.is_final():
|
||||||
_new_fill_context(context, stcls)
|
_new_fill_context(context, stcls)
|
||||||
scores = self.model.score(context)
|
scores = self.model.score(context)
|
||||||
guess = self.moves.best_valid(scores, stcls)
|
guess = self.moves.best_valid(scores, stcls)
|
||||||
best = self.moves.best_gold(scores, stcls, gold)
|
best = self.moves.best_gold(scores, stcls, gold)
|
||||||
cost = guess.get_cost(stcls, &gold.c, guess.label)
|
cost = guess.get_cost(stcls, &gold.c, guess.label)
|
||||||
self.model.update(context, guess.clas, best.clas, cost)
|
self.model.update(context, guess.clas, best.clas, cost)
|
||||||
guess.do(state, guess.label)
|
guess.do(stcls, guess.label)
|
||||||
loss += cost
|
loss += cost
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
@ -203,14 +207,16 @@ cdef class Parser:
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
||||||
self.moves.initialize_state(state)
|
self.moves.initialize_state(state)
|
||||||
|
cdef StateClass stcls = StateClass(state.sent_len)
|
||||||
|
stcls.from_struct(state)
|
||||||
|
|
||||||
cdef class_t clas
|
cdef class_t clas
|
||||||
cdef int n_feats
|
cdef int n_feats
|
||||||
for clas in hist:
|
for clas in hist:
|
||||||
fill_context(context, state)
|
_new_fill_context(context, stcls)
|
||||||
feats = self.model._extractor.get_feats(context, &n_feats)
|
feats = self.model._extractor.get_feats(context, &n_feats)
|
||||||
count_feats(counts[clas], feats, n_feats, inc)
|
count_feats(counts[clas], feats, n_feats, inc)
|
||||||
self.moves.c[clas].do(state, self.moves.c[clas].label)
|
self.moves.c[clas].do(stcls, self.moves.c[clas].label)
|
||||||
|
|
||||||
|
|
||||||
# These are passed as callbacks to thinc.search.Beam
|
# These are passed as callbacks to thinc.search.Beam
|
||||||
|
@ -220,7 +226,8 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
|
||||||
src = <const State*>_src
|
src = <const State*>_src
|
||||||
moves = <const Transition*>_moves
|
moves = <const Transition*>_moves
|
||||||
copy_state(dest, src)
|
copy_state(dest, src)
|
||||||
moves[clas].do(dest, moves[clas].label)
|
raise Exception
|
||||||
|
#moves[clas].do(dest, moves[clas].label)
|
||||||
|
|
||||||
|
|
||||||
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
||||||
|
|
|
@ -126,7 +126,7 @@ cdef class StateClass:
|
||||||
return self._b_i >= self.length
|
return self._b_i >= self.length
|
||||||
|
|
||||||
cdef bint is_final(self) nogil:
|
cdef bint is_final(self) nogil:
|
||||||
return self.eol() and self.empty()
|
return self.eol() and self.stack_depth() <= 1
|
||||||
|
|
||||||
cdef bint has_head(self, int i) nogil:
|
cdef bint has_head(self, int i) nogil:
|
||||||
return self.safe_get(i).head != 0
|
return self.safe_get(i).head != 0
|
||||||
|
@ -196,7 +196,7 @@ cdef class StateClass:
|
||||||
self._sent[i].ent_type = ent_type
|
self._sent[i].ent_type = ent_type
|
||||||
|
|
||||||
cdef void set_sent_end(self, int i) nogil:
|
cdef void set_sent_end(self, int i) nogil:
|
||||||
if 0 < i < self.length:
|
if 0 <= i < self.length:
|
||||||
self._sent[i].sent_end = True
|
self._sent[i].sent_end = True
|
||||||
|
|
||||||
cdef void clone(self, StateClass src) nogil:
|
cdef void clone(self, StateClass src) nogil:
|
||||||
|
@ -208,6 +208,17 @@ cdef class StateClass:
|
||||||
self._s_i = src._s_i
|
self._s_i = src._s_i
|
||||||
self._e_i = src._e_i
|
self._e_i = src._e_i
|
||||||
|
|
||||||
|
def print_state(self, words):
|
||||||
|
words = list(words) + ['_']
|
||||||
|
top = words[self.S(0)] + '_%d' % self.H(self.S(0))
|
||||||
|
second = words[self.S(1)] + '_%d' % self.H(self.S(1))
|
||||||
|
third = words[self.S(2)] + '_%d' % self.H(self.S(2))
|
||||||
|
n0 = words[self.B(0)]
|
||||||
|
n1 = words[self.B(1)]
|
||||||
|
return ' '.join((str(self.stack_depth()), third, second, top, '|', n0, n1))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# From https://en.wikipedia.org/wiki/Hamming_weight
|
# From https://en.wikipedia.org/wiki/Hamming_weight
|
||||||
cdef inline uint32_t _popcount(uint32_t x) nogil:
|
cdef inline uint32_t _popcount(uint32_t x) nogil:
|
||||||
|
|
|
@ -19,14 +19,14 @@ cdef struct Transition:
|
||||||
|
|
||||||
bint (*is_valid)(StateClass state, int label) except -1
|
bint (*is_valid)(StateClass state, int label) except -1
|
||||||
int (*get_cost)(StateClass state, const GoldParseC* gold, int label) except -1
|
int (*get_cost)(StateClass state, const GoldParseC* gold, int label) except -1
|
||||||
int (*do)(State* state, int label) except -1
|
int (*do)(StateClass state, int label) except -1
|
||||||
|
|
||||||
|
|
||||||
ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
|
ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
|
||||||
ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) except -1
|
ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) except -1
|
||||||
ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
|
ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
|
||||||
|
|
||||||
ctypedef int (*do_func_t)(State* state, int label) except -1
|
ctypedef int (*do_func_t)(StateClass state, int label) except -1
|
||||||
|
|
||||||
|
|
||||||
cdef class TransitionSystem:
|
cdef class TransitionSystem:
|
||||||
|
@ -37,7 +37,7 @@ cdef class TransitionSystem:
|
||||||
cdef readonly int n_moves
|
cdef readonly int n_moves
|
||||||
|
|
||||||
cdef int initialize_state(self, State* state) except -1
|
cdef int initialize_state(self, State* state) except -1
|
||||||
cdef int finalize_state(self, State* state) except -1
|
cdef int finalize_state(self, StateClass state) except -1
|
||||||
|
|
||||||
cdef int preprocess_gold(self, GoldParse gold) except -1
|
cdef int preprocess_gold(self, GoldParse gold) except -1
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ cdef class TransitionSystem:
|
||||||
cdef int initialize_state(self, State* state) except -1:
|
cdef int initialize_state(self, State* state) except -1:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
cdef int finalize_state(self, State* state) except -1:
|
cdef int finalize_state(self, StateClass state) except -1:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user