mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 09:44:36 +03:00
* Bug fixes to beam parser. Search still broken on non-gold sentences
This commit is contained in:
parent
1ec4e6fc95
commit
6e2564239d
|
@ -83,7 +83,7 @@ cdef class Parser:
|
||||||
def __call__(self, Tokens tokens):
|
def __call__(self, Tokens tokens):
|
||||||
if tokens.length == 0:
|
if tokens.length == 0:
|
||||||
return 0
|
return 0
|
||||||
if self.cfg.beam_width <= 1:
|
if self.cfg.get('beam_width', 1) <= 1:
|
||||||
self._greedy_parse(tokens)
|
self._greedy_parse(tokens)
|
||||||
else:
|
else:
|
||||||
self._beam_parse(tokens)
|
self._beam_parse(tokens)
|
||||||
|
@ -113,6 +113,7 @@ cdef class Parser:
|
||||||
cdef int _beam_parse(self, Tokens tokens) except -1:
|
cdef int _beam_parse(self, Tokens tokens) except -1:
|
||||||
cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width)
|
cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width)
|
||||||
beam.initialize(_init_state, tokens.length, tokens.data)
|
beam.initialize(_init_state, tokens.length, tokens.data)
|
||||||
|
beam.check_done(_check_final_state, NULL)
|
||||||
while not beam.is_done:
|
while not beam.is_done:
|
||||||
self._advance_beam(beam, None, False)
|
self._advance_beam(beam, None, False)
|
||||||
state = <State*>beam.at(0)
|
state = <State*>beam.at(0)
|
||||||
|
@ -145,8 +146,10 @@ cdef class Parser:
|
||||||
def _beam_train(self, Tokens tokens, GoldParse gold_parse):
|
def _beam_train(self, Tokens tokens, GoldParse gold_parse):
|
||||||
cdef Beam pred = Beam(self.moves.n_moves, self.cfg.beam_width)
|
cdef Beam pred = Beam(self.moves.n_moves, self.cfg.beam_width)
|
||||||
pred.initialize(_init_state, tokens.length, tokens.data)
|
pred.initialize(_init_state, tokens.length, tokens.data)
|
||||||
|
pred.check_done(_check_final_state, NULL)
|
||||||
cdef Beam gold = Beam(self.moves.n_moves, self.cfg.beam_width)
|
cdef Beam gold = Beam(self.moves.n_moves, self.cfg.beam_width)
|
||||||
gold.initialize(_init_state, tokens.length, tokens.data)
|
gold.initialize(_init_state, tokens.length, tokens.data)
|
||||||
|
gold.check_done(_check_final_state, NULL)
|
||||||
|
|
||||||
violn = MaxViolation()
|
violn = MaxViolation()
|
||||||
while not pred.is_done and not gold.is_done:
|
while not pred.is_done and not gold.is_done:
|
||||||
|
@ -170,9 +173,10 @@ cdef class Parser:
|
||||||
cdef const Transition* move
|
cdef const Transition* move
|
||||||
for i in range(beam.size):
|
for i in range(beam.size):
|
||||||
state = <State*>beam.at(i)
|
state = <State*>beam.at(i)
|
||||||
fill_context(context, state)
|
if not is_final(state):
|
||||||
self.model.set_scores(beam.scores[i], context)
|
fill_context(context, state)
|
||||||
self.moves.set_valid(beam.is_valid[i], state)
|
self.model.set_scores(beam.scores[i], context)
|
||||||
|
self.moves.set_valid(beam.is_valid[i], state)
|
||||||
|
|
||||||
if gold is not None:
|
if gold is not None:
|
||||||
for i in range(beam.size):
|
for i in range(beam.size):
|
||||||
|
@ -194,8 +198,6 @@ cdef class Parser:
|
||||||
cdef class_t clas
|
cdef class_t clas
|
||||||
cdef int n_feats
|
cdef int n_feats
|
||||||
for clas in hist:
|
for clas in hist:
|
||||||
if is_final(state):
|
|
||||||
break
|
|
||||||
fill_context(context, state)
|
fill_context(context, state)
|
||||||
feats = self.model._extractor.get_feats(context, &n_feats)
|
feats = self.model._extractor.get_feats(context, &n_feats)
|
||||||
count_feats(counts[clas], feats, n_feats, inc)
|
count_feats(counts[clas], feats, n_feats, inc)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user