* Remove State* from parser.pyx entirely, switching over to StateClass. Beam parsing still untested.

This commit is contained in:
Matthew Honnibal 2015-06-10 02:03:38 +02:00
parent f14a1526aa
commit 6a94b64eca

View File

@@ -35,7 +35,6 @@ from ..strings cimport StringStore
from .arc_eager cimport TransitionSystem, Transition
from .transition_system import OracleError
from ._state cimport State, new_state, copy_state, is_final, push_stack, get_left, get_n0
from ..gold cimport GoldParse
from . import _parse_features
@@ -43,6 +42,7 @@ from ._parse_features cimport CONTEXT_SIZE
from ._parse_features cimport fill_context
from .stateclass cimport StateClass
from cpython.ref cimport PyObject
DEBUG = False
def set_debug(val):
@@ -50,20 +50,6 @@ def set_debug(val):
DEBUG = val
cdef unicode print_state(State* s, list words):
    # Build a one-line, space-separated debug summary of the parse state `s`:
    #   <entity or '-'> <stack length> <stack[-2]> <stack[-1]> <stack[0]> | <n0> <n1>
    # Each stack word is suffixed with '_<head>', where <head> is read from the
    # matching entry of `s.sent`. n0/n1 are the words at positions `s.i` and
    # `s.i + 1`, or 'EOL' when that position is past the end of the word list.

    # Work on a copy with 'EOL' appended, leaving the caller's list unmodified.
    padded = list(words)
    padded.append('EOL')
    n_padded = len(padded)

    # The three stack slots, looked up in the same order as they are indexed.
    s0 = padded[s.stack[0]] + '_%d' % s.sent[s.stack[0]].head
    s1 = padded[s.stack[-1]] + '_%d' % s.sent[s.stack[-1]].head
    s2 = padded[s.stack[-2]] + '_%d' % s.sent[s.stack[-2]].head

    # The next two buffer positions, falling back to 'EOL' when out of range.
    if s.i < n_padded:
        b0 = padded[s.i]
    else:
        b0 = 'EOL'
    if s.i + 1 < n_padded:
        b1 = padded[s.i + 1]
    else:
        b1 = 'EOL'

    # The entity is only described when at least one entity has been recorded.
    ent_desc = '%s %d-%d' % (s.ent.label, s.ent.start, s.ent.end) if s.ents_len else '-'

    pieces = [ent_desc, str(s.stack_len), s2, s1, s0, '|', b0, b1]
    return ' '.join(pieces)
def get_templates(name):
pf = _parse_features
if name == 'ner':
@@ -102,10 +88,8 @@ cdef class Parser:
cdef atom_t[CONTEXT_SIZE] context
cdef int n_feats
cdef Pool mem = Pool()
cdef State* state = new_state(mem, tokens.data, tokens.length)
self.moves.initialize_state(state)
cdef StateClass stcls = StateClass(state.sent_len)
stcls.from_struct(state)
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
self.moves.initialize_state(stcls)
cdef Transition guess
words = [w.orth_ for w in tokens]
while not stcls.is_final():
@@ -123,23 +107,21 @@ cdef class Parser:
beam.check_done(_check_final_state, NULL)
while not beam.is_done:
self._advance_beam(beam, None, False)
state = <State*>beam.at(0)
state = <StateClass>beam.at(0)
#self.moves.finalize_state(state)
#tokens.set_parse(state.sent)
raise Exception
def _greedy_train(self, Tokens tokens, GoldParse gold):
cdef Pool mem = Pool()
cdef State* state = new_state(mem, tokens.data, tokens.length)
self.moves.initialize_state(state)
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
self.moves.initialize_state(stcls)
cdef int cost
cdef const Feature* feats
cdef const weight_t* scores
cdef Transition guess
cdef Transition best
cdef StateClass stcls = StateClass(state.sent_len)
stcls.from_struct(state)
cdef atom_t[CONTEXT_SIZE] context
loss = 0
words = [w.orth_ for w in tokens]
@@ -178,36 +160,32 @@ cdef class Parser:
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
cdef atom_t[CONTEXT_SIZE] context
cdef State* state
cdef int i, j, cost
cdef bint is_valid
cdef const Transition* move
cdef StateClass stcls = StateClass(gold.length)
for i in range(beam.size):
state = <State*>beam.at(i)
stcls.from_struct(state)
if not is_final(state):
stcls = <StateClass>beam.at(i)
if not stcls.is_final():
fill_context(context, stcls)
self.model.set_scores(beam.scores[i], context)
self.moves.set_valid(beam.is_valid[i], stcls)
if gold is not None:
for i in range(beam.size):
state = <State*>beam.at(i)
stcls = <StateClass>beam.at(i)
self.moves.set_costs(beam.costs[i], stcls, gold)
if follow_gold:
for j in range(self.moves.n_moves):
beam.is_valid[i][j] *= beam.costs[i][j] == 0
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
beam.advance(_transition_state, NULL, <void*>self.moves.c)
beam.check_done(_check_final_state, NULL)
def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):
cdef atom_t[CONTEXT_SIZE] context
cdef Pool mem = Pool()
cdef State* state = new_state(mem, tokens.data, tokens.length)
self.moves.initialize_state(state)
cdef StateClass stcls = StateClass(state.sent_len)
stcls.from_struct(state)
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
self.moves.initialize_state(stcls)
cdef class_t clas
cdef int n_feats
@@ -221,24 +199,23 @@ cdef class Parser:
# These are passed as callbacks to thinc.search.Beam
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
dest = <State*>_dest
src = <const State*>_src
dest = <StateClass>_dest
src = <StateClass>_src
moves = <const Transition*>_moves
copy_state(dest, src)
raise Exception
#moves[clas].do(dest, moves[clas].label)
dest.clone(src)
moves[clas].do(dest, moves[clas].label)
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
state = new_state(mem, <const TokenC*>tokens, length)
push_stack(state)
return state
cdef StateClass st = StateClass.init(<const TokenC*>tokens, length)
return <void*>st
cdef int _check_final_state(void* state, void* extra_args) except -1:
return is_final(<State*>state)
cdef int _check_final_state(void* _state, void* extra_args) except -1:
return (<StateClass>_state).is_final()
"""
cdef hash_t _hash_state(void* _state, void* _) except 0:
state = <const State*>_state
cdef atom_t[10] rep
@@ -257,3 +234,4 @@ cdef hash_t _hash_state(void* _state, void* _) except 0:
rep[8] = 0
rep[9] = state.sent[state.i].l_kids
return hash64(rep, sizeof(atom_t) * 10, 0)
"""