mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 09:44:36 +03:00
* Remove State* from parser.pyx entirely, switching over to StateClass. Beam parsing still untested.
This commit is contained in:
parent
f14a1526aa
commit
6a94b64eca
|
@ -35,7 +35,6 @@ from ..strings cimport StringStore
|
||||||
from .arc_eager cimport TransitionSystem, Transition
|
from .arc_eager cimport TransitionSystem, Transition
|
||||||
from .transition_system import OracleError
|
from .transition_system import OracleError
|
||||||
|
|
||||||
from ._state cimport State, new_state, copy_state, is_final, push_stack, get_left, get_n0
|
|
||||||
from ..gold cimport GoldParse
|
from ..gold cimport GoldParse
|
||||||
|
|
||||||
from . import _parse_features
|
from . import _parse_features
|
||||||
|
@ -43,6 +42,7 @@ from ._parse_features cimport CONTEXT_SIZE
|
||||||
from ._parse_features cimport fill_context
|
from ._parse_features cimport fill_context
|
||||||
from .stateclass cimport StateClass
|
from .stateclass cimport StateClass
|
||||||
|
|
||||||
|
from cpython.ref cimport PyObject
|
||||||
|
|
||||||
DEBUG = False
|
DEBUG = False
|
||||||
def set_debug(val):
|
def set_debug(val):
|
||||||
|
@ -50,20 +50,6 @@ def set_debug(val):
|
||||||
DEBUG = val
|
DEBUG = val
|
||||||
|
|
||||||
|
|
||||||
cdef unicode print_state(State* s, list words):
|
|
||||||
words = list(words) + ['EOL']
|
|
||||||
top = words[s.stack[0]] + '_%d' % s.sent[s.stack[0]].head
|
|
||||||
second = words[s.stack[-1]] + '_%d' % s.sent[s.stack[-1]].head
|
|
||||||
third = words[s.stack[-2]] + '_%d' % s.sent[s.stack[-2]].head
|
|
||||||
n0 = words[s.i] if s.i < len(words) else 'EOL'
|
|
||||||
n1 = words[s.i + 1] if s.i+1 < len(words) else 'EOL'
|
|
||||||
if s.ents_len:
|
|
||||||
ent = '%s %d-%d' % (s.ent.label, s.ent.start, s.ent.end)
|
|
||||||
else:
|
|
||||||
ent = '-'
|
|
||||||
return ' '.join((ent, str(s.stack_len), third, second, top, '|', n0, n1))
|
|
||||||
|
|
||||||
|
|
||||||
def get_templates(name):
|
def get_templates(name):
|
||||||
pf = _parse_features
|
pf = _parse_features
|
||||||
if name == 'ner':
|
if name == 'ner':
|
||||||
|
@ -102,10 +88,8 @@ cdef class Parser:
|
||||||
cdef atom_t[CONTEXT_SIZE] context
|
cdef atom_t[CONTEXT_SIZE] context
|
||||||
cdef int n_feats
|
cdef int n_feats
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
||||||
self.moves.initialize_state(state)
|
self.moves.initialize_state(stcls)
|
||||||
cdef StateClass stcls = StateClass(state.sent_len)
|
|
||||||
stcls.from_struct(state)
|
|
||||||
cdef Transition guess
|
cdef Transition guess
|
||||||
words = [w.orth_ for w in tokens]
|
words = [w.orth_ for w in tokens]
|
||||||
while not stcls.is_final():
|
while not stcls.is_final():
|
||||||
|
@ -123,23 +107,21 @@ cdef class Parser:
|
||||||
beam.check_done(_check_final_state, NULL)
|
beam.check_done(_check_final_state, NULL)
|
||||||
while not beam.is_done:
|
while not beam.is_done:
|
||||||
self._advance_beam(beam, None, False)
|
self._advance_beam(beam, None, False)
|
||||||
state = <State*>beam.at(0)
|
state = <StateClass>beam.at(0)
|
||||||
#self.moves.finalize_state(state)
|
#self.moves.finalize_state(state)
|
||||||
#tokens.set_parse(state.sent)
|
#tokens.set_parse(state.sent)
|
||||||
raise Exception
|
raise Exception
|
||||||
|
|
||||||
def _greedy_train(self, Tokens tokens, GoldParse gold):
|
def _greedy_train(self, Tokens tokens, GoldParse gold):
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
||||||
self.moves.initialize_state(state)
|
self.moves.initialize_state(stcls)
|
||||||
|
|
||||||
cdef int cost
|
cdef int cost
|
||||||
cdef const Feature* feats
|
cdef const Feature* feats
|
||||||
cdef const weight_t* scores
|
cdef const weight_t* scores
|
||||||
cdef Transition guess
|
cdef Transition guess
|
||||||
cdef Transition best
|
cdef Transition best
|
||||||
cdef StateClass stcls = StateClass(state.sent_len)
|
|
||||||
stcls.from_struct(state)
|
|
||||||
cdef atom_t[CONTEXT_SIZE] context
|
cdef atom_t[CONTEXT_SIZE] context
|
||||||
loss = 0
|
loss = 0
|
||||||
words = [w.orth_ for w in tokens]
|
words = [w.orth_ for w in tokens]
|
||||||
|
@ -178,36 +160,32 @@ cdef class Parser:
|
||||||
|
|
||||||
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
|
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
|
||||||
cdef atom_t[CONTEXT_SIZE] context
|
cdef atom_t[CONTEXT_SIZE] context
|
||||||
cdef State* state
|
|
||||||
cdef int i, j, cost
|
cdef int i, j, cost
|
||||||
cdef bint is_valid
|
cdef bint is_valid
|
||||||
cdef const Transition* move
|
cdef const Transition* move
|
||||||
cdef StateClass stcls = StateClass(gold.length)
|
cdef StateClass stcls = StateClass(gold.length)
|
||||||
for i in range(beam.size):
|
for i in range(beam.size):
|
||||||
state = <State*>beam.at(i)
|
stcls = <StateClass>beam.at(i)
|
||||||
stcls.from_struct(state)
|
if not stcls.is_final():
|
||||||
if not is_final(state):
|
|
||||||
fill_context(context, stcls)
|
fill_context(context, stcls)
|
||||||
self.model.set_scores(beam.scores[i], context)
|
self.model.set_scores(beam.scores[i], context)
|
||||||
self.moves.set_valid(beam.is_valid[i], stcls)
|
self.moves.set_valid(beam.is_valid[i], stcls)
|
||||||
|
|
||||||
if gold is not None:
|
if gold is not None:
|
||||||
for i in range(beam.size):
|
for i in range(beam.size):
|
||||||
state = <State*>beam.at(i)
|
stcls = <StateClass>beam.at(i)
|
||||||
self.moves.set_costs(beam.costs[i], stcls, gold)
|
self.moves.set_costs(beam.costs[i], stcls, gold)
|
||||||
if follow_gold:
|
if follow_gold:
|
||||||
for j in range(self.moves.n_moves):
|
for j in range(self.moves.n_moves):
|
||||||
beam.is_valid[i][j] *= beam.costs[i][j] == 0
|
beam.is_valid[i][j] *= beam.costs[i][j] == 0
|
||||||
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
beam.advance(_transition_state, NULL, <void*>self.moves.c)
|
||||||
beam.check_done(_check_final_state, NULL)
|
beam.check_done(_check_final_state, NULL)
|
||||||
|
|
||||||
def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):
|
def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):
|
||||||
cdef atom_t[CONTEXT_SIZE] context
|
cdef atom_t[CONTEXT_SIZE] context
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
||||||
self.moves.initialize_state(state)
|
self.moves.initialize_state(stcls)
|
||||||
cdef StateClass stcls = StateClass(state.sent_len)
|
|
||||||
stcls.from_struct(state)
|
|
||||||
|
|
||||||
cdef class_t clas
|
cdef class_t clas
|
||||||
cdef int n_feats
|
cdef int n_feats
|
||||||
|
@ -221,24 +199,23 @@ cdef class Parser:
|
||||||
# These are passed as callbacks to thinc.search.Beam
|
# These are passed as callbacks to thinc.search.Beam
|
||||||
|
|
||||||
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
||||||
dest = <State*>_dest
|
dest = <StateClass>_dest
|
||||||
src = <const State*>_src
|
src = <StateClass>_src
|
||||||
moves = <const Transition*>_moves
|
moves = <const Transition*>_moves
|
||||||
copy_state(dest, src)
|
dest.clone(src)
|
||||||
raise Exception
|
moves[clas].do(dest, moves[clas].label)
|
||||||
#moves[clas].do(dest, moves[clas].label)
|
|
||||||
|
|
||||||
|
|
||||||
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
||||||
state = new_state(mem, <const TokenC*>tokens, length)
|
cdef StateClass st = StateClass.init(<const TokenC*>tokens, length)
|
||||||
push_stack(state)
|
return <void*>st
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
cdef int _check_final_state(void* state, void* extra_args) except -1:
|
cdef int _check_final_state(void* _state, void* extra_args) except -1:
|
||||||
return is_final(<State*>state)
|
return (<StateClass>_state).is_final()
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
cdef hash_t _hash_state(void* _state, void* _) except 0:
|
cdef hash_t _hash_state(void* _state, void* _) except 0:
|
||||||
state = <const State*>_state
|
state = <const State*>_state
|
||||||
cdef atom_t[10] rep
|
cdef atom_t[10] rep
|
||||||
|
@ -257,3 +234,4 @@ cdef hash_t _hash_state(void* _state, void* _) except 0:
|
||||||
rep[8] = 0
|
rep[8] = 0
|
||||||
rep[9] = state.sent[state.i].l_kids
|
rep[9] = state.sent[state.i].l_kids
|
||||||
return hash64(rep, sizeof(atom_t) * 10, 0)
|
return hash64(rep, sizeof(atom_t) * 10, 0)
|
||||||
|
"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user