mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
* Remove State* from parser.pyx entirely, switching over to StateClass. Beam parsing still untested.
This commit is contained in:
parent
f14a1526aa
commit
6a94b64eca
|
@ -35,7 +35,6 @@ from ..strings cimport StringStore
|
|||
from .arc_eager cimport TransitionSystem, Transition
|
||||
from .transition_system import OracleError
|
||||
|
||||
from ._state cimport State, new_state, copy_state, is_final, push_stack, get_left, get_n0
|
||||
from ..gold cimport GoldParse
|
||||
|
||||
from . import _parse_features
|
||||
|
@ -43,6 +42,7 @@ from ._parse_features cimport CONTEXT_SIZE
|
|||
from ._parse_features cimport fill_context
|
||||
from .stateclass cimport StateClass
|
||||
|
||||
from cpython.ref cimport PyObject
|
||||
|
||||
DEBUG = False
|
||||
def set_debug(val):
|
||||
|
@ -50,20 +50,6 @@ def set_debug(val):
|
|||
DEBUG = val
|
||||
|
||||
|
||||
cdef unicode print_state(State* s, list words):
|
||||
words = list(words) + ['EOL']
|
||||
top = words[s.stack[0]] + '_%d' % s.sent[s.stack[0]].head
|
||||
second = words[s.stack[-1]] + '_%d' % s.sent[s.stack[-1]].head
|
||||
third = words[s.stack[-2]] + '_%d' % s.sent[s.stack[-2]].head
|
||||
n0 = words[s.i] if s.i < len(words) else 'EOL'
|
||||
n1 = words[s.i + 1] if s.i+1 < len(words) else 'EOL'
|
||||
if s.ents_len:
|
||||
ent = '%s %d-%d' % (s.ent.label, s.ent.start, s.ent.end)
|
||||
else:
|
||||
ent = '-'
|
||||
return ' '.join((ent, str(s.stack_len), third, second, top, '|', n0, n1))
|
||||
|
||||
|
||||
def get_templates(name):
|
||||
pf = _parse_features
|
||||
if name == 'ner':
|
||||
|
@ -102,10 +88,8 @@ cdef class Parser:
|
|||
cdef atom_t[CONTEXT_SIZE] context
|
||||
cdef int n_feats
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
||||
self.moves.initialize_state(state)
|
||||
cdef StateClass stcls = StateClass(state.sent_len)
|
||||
stcls.from_struct(state)
|
||||
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
||||
self.moves.initialize_state(stcls)
|
||||
cdef Transition guess
|
||||
words = [w.orth_ for w in tokens]
|
||||
while not stcls.is_final():
|
||||
|
@ -123,23 +107,21 @@ cdef class Parser:
|
|||
beam.check_done(_check_final_state, NULL)
|
||||
while not beam.is_done:
|
||||
self._advance_beam(beam, None, False)
|
||||
state = <State*>beam.at(0)
|
||||
state = <StateClass>beam.at(0)
|
||||
#self.moves.finalize_state(state)
|
||||
#tokens.set_parse(state.sent)
|
||||
raise Exception
|
||||
|
||||
def _greedy_train(self, Tokens tokens, GoldParse gold):
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
||||
self.moves.initialize_state(state)
|
||||
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
||||
self.moves.initialize_state(stcls)
|
||||
|
||||
cdef int cost
|
||||
cdef const Feature* feats
|
||||
cdef const weight_t* scores
|
||||
cdef Transition guess
|
||||
cdef Transition best
|
||||
cdef StateClass stcls = StateClass(state.sent_len)
|
||||
stcls.from_struct(state)
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
loss = 0
|
||||
words = [w.orth_ for w in tokens]
|
||||
|
@ -178,36 +160,32 @@ cdef class Parser:
|
|||
|
||||
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
cdef State* state
|
||||
cdef int i, j, cost
|
||||
cdef bint is_valid
|
||||
cdef const Transition* move
|
||||
cdef StateClass stcls = StateClass(gold.length)
|
||||
for i in range(beam.size):
|
||||
state = <State*>beam.at(i)
|
||||
stcls.from_struct(state)
|
||||
if not is_final(state):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
if not stcls.is_final():
|
||||
fill_context(context, stcls)
|
||||
self.model.set_scores(beam.scores[i], context)
|
||||
self.moves.set_valid(beam.is_valid[i], stcls)
|
||||
|
||||
if gold is not None:
|
||||
for i in range(beam.size):
|
||||
state = <State*>beam.at(i)
|
||||
stcls = <StateClass>beam.at(i)
|
||||
self.moves.set_costs(beam.costs[i], stcls, gold)
|
||||
if follow_gold:
|
||||
for j in range(self.moves.n_moves):
|
||||
beam.is_valid[i][j] *= beam.costs[i][j] == 0
|
||||
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
||||
beam.advance(_transition_state, NULL, <void*>self.moves.c)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
|
||||
def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
||||
self.moves.initialize_state(state)
|
||||
cdef StateClass stcls = StateClass(state.sent_len)
|
||||
stcls.from_struct(state)
|
||||
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
||||
self.moves.initialize_state(stcls)
|
||||
|
||||
cdef class_t clas
|
||||
cdef int n_feats
|
||||
|
@ -221,24 +199,23 @@ cdef class Parser:
|
|||
# These are passed as callbacks to thinc.search.Beam
|
||||
|
||||
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
||||
dest = <State*>_dest
|
||||
src = <const State*>_src
|
||||
dest = <StateClass>_dest
|
||||
src = <StateClass>_src
|
||||
moves = <const Transition*>_moves
|
||||
copy_state(dest, src)
|
||||
raise Exception
|
||||
#moves[clas].do(dest, moves[clas].label)
|
||||
dest.clone(src)
|
||||
moves[clas].do(dest, moves[clas].label)
|
||||
|
||||
|
||||
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
||||
state = new_state(mem, <const TokenC*>tokens, length)
|
||||
push_stack(state)
|
||||
return state
|
||||
cdef StateClass st = StateClass.init(<const TokenC*>tokens, length)
|
||||
return <void*>st
|
||||
|
||||
|
||||
cdef int _check_final_state(void* state, void* extra_args) except -1:
|
||||
return is_final(<State*>state)
|
||||
cdef int _check_final_state(void* _state, void* extra_args) except -1:
|
||||
return (<StateClass>_state).is_final()
|
||||
|
||||
|
||||
"""
|
||||
cdef hash_t _hash_state(void* _state, void* _) except 0:
|
||||
state = <const State*>_state
|
||||
cdef atom_t[10] rep
|
||||
|
@ -257,3 +234,4 @@ cdef hash_t _hash_state(void* _state, void* _) except 0:
|
|||
rep[8] = 0
|
||||
rep[9] = state.sent[state.i].l_kids
|
||||
return hash64(rep, sizeof(atom_t) * 10, 0)
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue
Block a user