mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
* Bug fixes to beam parser. Search still broken on non-gold sentences
This commit is contained in:
parent
1ec4e6fc95
commit
6e2564239d
|
@ -83,7 +83,7 @@ cdef class Parser:
|
|||
def __call__(self, Tokens tokens):
|
||||
if tokens.length == 0:
|
||||
return 0
|
||||
if self.cfg.beam_width <= 1:
|
||||
if self.cfg.get('beam_width', 1) <= 1:
|
||||
self._greedy_parse(tokens)
|
||||
else:
|
||||
self._beam_parse(tokens)
|
||||
|
@ -113,6 +113,7 @@ cdef class Parser:
|
|||
cdef int _beam_parse(self, Tokens tokens) except -1:
|
||||
cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width)
|
||||
beam.initialize(_init_state, tokens.length, tokens.data)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
while not beam.is_done:
|
||||
self._advance_beam(beam, None, False)
|
||||
state = <State*>beam.at(0)
|
||||
|
@ -145,8 +146,10 @@ cdef class Parser:
|
|||
def _beam_train(self, Tokens tokens, GoldParse gold_parse):
|
||||
cdef Beam pred = Beam(self.moves.n_moves, self.cfg.beam_width)
|
||||
pred.initialize(_init_state, tokens.length, tokens.data)
|
||||
pred.check_done(_check_final_state, NULL)
|
||||
cdef Beam gold = Beam(self.moves.n_moves, self.cfg.beam_width)
|
||||
gold.initialize(_init_state, tokens.length, tokens.data)
|
||||
gold.check_done(_check_final_state, NULL)
|
||||
|
||||
violn = MaxViolation()
|
||||
while not pred.is_done and not gold.is_done:
|
||||
|
@ -170,9 +173,10 @@ cdef class Parser:
|
|||
cdef const Transition* move
|
||||
for i in range(beam.size):
|
||||
state = <State*>beam.at(i)
|
||||
fill_context(context, state)
|
||||
self.model.set_scores(beam.scores[i], context)
|
||||
self.moves.set_valid(beam.is_valid[i], state)
|
||||
if not is_final(state):
|
||||
fill_context(context, state)
|
||||
self.model.set_scores(beam.scores[i], context)
|
||||
self.moves.set_valid(beam.is_valid[i], state)
|
||||
|
||||
if gold is not None:
|
||||
for i in range(beam.size):
|
||||
|
@ -194,8 +198,6 @@ cdef class Parser:
|
|||
cdef class_t clas
|
||||
cdef int n_feats
|
||||
for clas in hist:
|
||||
if is_final(state):
|
||||
break
|
||||
fill_context(context, state)
|
||||
feats = self.model._extractor.get_feats(context, &n_feats)
|
||||
count_feats(counts[clas], feats, n_feats, inc)
|
||||
|
|
Loading…
Reference in New Issue
Block a user