diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index fab990ef9..93fdff043 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -35,7 +35,6 @@ from ..strings cimport StringStore from .arc_eager cimport TransitionSystem, Transition from .transition_system import OracleError -from ._state cimport State, new_state, copy_state, is_final, push_stack, get_left, get_n0 from ..gold cimport GoldParse from . import _parse_features @@ -43,6 +42,7 @@ from ._parse_features cimport CONTEXT_SIZE from ._parse_features cimport fill_context from .stateclass cimport StateClass +from cpython.ref cimport PyObject DEBUG = False def set_debug(val): @@ -50,20 +50,6 @@ def set_debug(val): DEBUG = val -cdef unicode print_state(State* s, list words): - words = list(words) + ['EOL'] - top = words[s.stack[0]] + '_%d' % s.sent[s.stack[0]].head - second = words[s.stack[-1]] + '_%d' % s.sent[s.stack[-1]].head - third = words[s.stack[-2]] + '_%d' % s.sent[s.stack[-2]].head - n0 = words[s.i] if s.i < len(words) else 'EOL' - n1 = words[s.i + 1] if s.i+1 < len(words) else 'EOL' - if s.ents_len: - ent = '%s %d-%d' % (s.ent.label, s.ent.start, s.ent.end) - else: - ent = '-' - return ' '.join((ent, str(s.stack_len), third, second, top, '|', n0, n1)) - - def get_templates(name): pf = _parse_features if name == 'ner': @@ -102,10 +88,8 @@ cdef class Parser: cdef atom_t[CONTEXT_SIZE] context cdef int n_feats cdef Pool mem = Pool() - cdef State* state = new_state(mem, tokens.data, tokens.length) - self.moves.initialize_state(state) - cdef StateClass stcls = StateClass(state.sent_len) - stcls.from_struct(state) + cdef StateClass stcls = StateClass.init(tokens.data, tokens.length) + self.moves.initialize_state(stcls) cdef Transition guess words = [w.orth_ for w in tokens] while not stcls.is_final(): @@ -123,23 +107,21 @@ cdef class Parser: beam.check_done(_check_final_state, NULL) while not beam.is_done: self._advance_beam(beam, None, False) - state = beam.at(0) + state = beam.at(0) #self.moves.finalize_state(state) #tokens.set_parse(state.sent) raise Exception def _greedy_train(self, Tokens tokens, GoldParse gold): cdef Pool mem = Pool() - cdef State* state = new_state(mem, tokens.data, tokens.length) - self.moves.initialize_state(state) + cdef StateClass stcls = StateClass.init(tokens.data, tokens.length) + self.moves.initialize_state(stcls) cdef int cost cdef const Feature* feats cdef const weight_t* scores cdef Transition guess cdef Transition best - cdef StateClass stcls = StateClass(state.sent_len) - stcls.from_struct(state) cdef atom_t[CONTEXT_SIZE] context loss = 0 words = [w.orth_ for w in tokens] @@ -178,36 +160,32 @@ cdef class Parser: def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold): cdef atom_t[CONTEXT_SIZE] context - cdef State* state cdef int i, j, cost cdef bint is_valid cdef const Transition* move cdef StateClass stcls = StateClass(gold.length) for i in range(beam.size): - state = beam.at(i) - stcls.from_struct(state) - if not is_final(state): + stcls = beam.at(i) + if not stcls.is_final(): fill_context(context, stcls) self.model.set_scores(beam.scores[i], context) self.moves.set_valid(beam.is_valid[i], stcls) if gold is not None: for i in range(beam.size): - state = beam.at(i) + stcls = beam.at(i) self.moves.set_costs(beam.costs[i], stcls, gold) if follow_gold: for j in range(self.moves.n_moves): beam.is_valid[i][j] *= beam.costs[i][j] == 0 - beam.advance(_transition_state, _hash_state, self.moves.c) + beam.advance(_transition_state, NULL, self.moves.c) beam.check_done(_check_final_state, NULL) def _count_feats(self, dict counts, Tokens tokens, list hist, int inc): cdef atom_t[CONTEXT_SIZE] context cdef Pool mem = Pool() - cdef State* state = new_state(mem, tokens.data, tokens.length) - self.moves.initialize_state(state) - cdef StateClass stcls = StateClass(state.sent_len) - stcls.from_struct(state) + cdef StateClass stcls = StateClass.init(tokens.data, tokens.length) + self.moves.initialize_state(stcls) cdef class_t clas cdef int n_feats @@ -221,24 +199,23 @@ cdef class Parser: # These are passed as callbacks to thinc.search.Beam cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1: - dest = _dest - src = _src + dest = _dest + src = _src moves = _moves - copy_state(dest, src) - raise Exception - #moves[clas].do(dest, moves[clas].label) + dest.clone(src) + moves[clas].do(dest, moves[clas].label) cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: - state = new_state(mem, tokens, length) - push_stack(state) - return state + cdef StateClass st = StateClass.init(tokens, length) + return st -cdef int _check_final_state(void* state, void* extra_args) except -1: - return is_final(state) +cdef int _check_final_state(void* _state, void* extra_args) except -1: + return (_state).is_final() +""" cdef hash_t _hash_state(void* _state, void* _) except 0: state = _state cdef atom_t[10] rep @@ -257,3 +234,4 @@ cdef hash_t _hash_state(void* _state, void* _) except 0: rep[8] = 0 rep[9] = state.sent[state.i].l_kids return hash64(rep, sizeof(atom_t) * 10, 0) +"""