mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
Readd beam search after refactor
This commit is contained in:
parent
36b2c9bdd5
commit
5a0f26be0c
1
setup.py
1
setup.py
|
@ -31,6 +31,7 @@ MOD_NAMES = [
|
|||
'spacy.tokenizer',
|
||||
'spacy.syntax.nn_parser',
|
||||
'spacy.syntax._parser_model',
|
||||
'spacy.syntax._beam_utils',
|
||||
'spacy.syntax.nonproj',
|
||||
'spacy.syntax.transition_system',
|
||||
'spacy.syntax.arc_eager',
|
||||
|
|
6
spacy/syntax/_beam_utils.pxd
Normal file
6
spacy/syntax/_beam_utils.pxd
Normal file
|
@ -0,0 +1,6 @@
|
|||
from thinc.typedefs cimport class_t
|
||||
|
||||
# These are passed as callbacks to thinc.search.Beam
|
||||
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1
|
||||
|
||||
cdef int check_final_state(void* _state, void* extra_args) except -1
|
|
@ -15,7 +15,7 @@ from .stateclass cimport StateC, StateClass
|
|||
|
||||
|
||||
# These are passed as callbacks to thinc.search.Beam
|
||||
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
||||
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
||||
dest = <StateC*>_dest
|
||||
src = <StateC*>_src
|
||||
moves = <const Transition*>_moves
|
||||
|
@ -24,12 +24,12 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
|
|||
dest.push_hist(clas)
|
||||
|
||||
|
||||
cdef int _check_final_state(void* _state, void* extra_args) except -1:
|
||||
cdef int check_final_state(void* _state, void* extra_args) except -1:
|
||||
state = <StateC*>_state
|
||||
return state.is_final()
|
||||
|
||||
|
||||
cdef hash_t _hash_state(void* _state, void* _) except 0:
|
||||
cdef hash_t hash_state(void* _state, void* _) except 0:
|
||||
state = <StateC*>_state
|
||||
if state.is_final():
|
||||
return 1
|
||||
|
@ -37,6 +37,20 @@ cdef hash_t _hash_state(void* _state, void* _) except 0:
|
|||
return state.hash()
|
||||
|
||||
|
||||
def collect_states(beams):
|
||||
cdef StateClass state
|
||||
cdef Beam beam
|
||||
states = []
|
||||
for state_or_beam in beams:
|
||||
if isinstance(state_or_beam, StateClass):
|
||||
states.append(state_or_beam)
|
||||
else:
|
||||
beam = state_or_beam
|
||||
state = StateClass.borrow(<StateC*>beam.at(0))
|
||||
states.append(state)
|
||||
return states
|
||||
|
||||
|
||||
cdef class ParserBeam(object):
|
||||
cdef public TransitionSystem moves
|
||||
cdef public object states
|
||||
|
@ -82,8 +96,8 @@ cdef class ParserBeam(object):
|
|||
self._set_scores(beam, scores[i])
|
||||
if self.golds is not None:
|
||||
self._set_costs(beam, self.golds[i], follow_gold=follow_gold)
|
||||
beam.advance(_transition_state, NULL, <void*>self.moves.c)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
beam.advance(transition_state, NULL, <void*>self.moves.c)
|
||||
beam.check_done(check_final_state, NULL)
|
||||
# This handles the non-monotonic stuff for the parser.
|
||||
if beam.is_done and self.golds is not None:
|
||||
for j in range(beam.size):
|
||||
|
@ -144,15 +158,12 @@ nr_update = 0
|
|||
def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
||||
states, golds,
|
||||
state2vec, vec2scores,
|
||||
int width, float density, int hist_feats,
|
||||
losses=None, drop=0.):
|
||||
int width, losses=None, drop=0.):
|
||||
global nr_update
|
||||
cdef MaxViolation violn
|
||||
nr_update += 1
|
||||
pbeam = ParserBeam(moves, states, golds,
|
||||
width=width, density=density)
|
||||
gbeam = ParserBeam(moves, states, golds,
|
||||
width=width, density=density)
|
||||
pbeam = ParserBeam(moves, states, golds, width=width)
|
||||
gbeam = ParserBeam(moves, states, golds, width=width)
|
||||
cdef StateClass state
|
||||
beam_maps = []
|
||||
backprops = []
|
||||
|
@ -177,12 +188,6 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
|||
# Now that we have our flat list of states, feed them through the model
|
||||
token_ids = get_token_ids(states, nr_feature)
|
||||
vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop)
|
||||
if hist_feats:
|
||||
hists = numpy.asarray([st.history[:hist_feats] for st in states],
|
||||
dtype='i')
|
||||
scores, bp_scores = vec2scores.begin_update((vectors, hists),
|
||||
drop=drop)
|
||||
else:
|
||||
scores, bp_scores = vec2scores.begin_update(vectors, drop=drop)
|
||||
|
||||
# Store the callbacks for the backward pass
|
||||
|
@ -291,3 +296,27 @@ def get_gradient(nr_class, beam_maps, histories, losses):
|
|||
grads[j][i, clas] += loss
|
||||
key = key + tuple([clas])
|
||||
return grads
|
||||
|
||||
|
||||
def cleanup_beam(Beam beam):
|
||||
cdef StateC* state
|
||||
# Once parsing has finished, states in beam may not be unique. Is this
|
||||
# correct?
|
||||
seen = set()
|
||||
for i in range(beam.width):
|
||||
addr = <size_t>beam._parents[i].content
|
||||
if addr not in seen:
|
||||
state = <StateC*>addr
|
||||
del state
|
||||
seen.add(addr)
|
||||
else:
|
||||
raise ValueError(Errors.E023.format(addr=addr, i=i))
|
||||
addr = <size_t>beam._states[i].content
|
||||
if addr not in seen:
|
||||
state = <StateC*>addr
|
||||
del state
|
||||
seen.add(addr)
|
||||
else:
|
||||
raise ValueError(Errors.E023.format(addr=addr, i=i))
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user