* Fix bugs in new greedy/beam parser

This commit is contained in:
Matthew Honnibal 2015-06-02 02:01:33 +02:00
parent 66dfa95847
commit e822df0867
2 changed files with 14 additions and 14 deletions

View File

@ -140,7 +140,7 @@ def _tag_partition(nlp, docs, gold_preproc=False):
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
train_tags=None):
train_tags=None, beam_width=1):
dep_model_dir = path.join(model_dir, 'deps')
pos_model_dir = path.join(model_dir, 'pos')
ner_model_dir = path.join(model_dir, 'ner')
@ -158,9 +158,10 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
labels=Language.ParserTransitionSystem.get_labels(gold_tuples),
beam_width=16)
beam_width=beam_width)
Config.write(ner_model_dir, 'config', features='ner', seed=seed,
labels=Language.EntityTransitionSystem.get_labels(gold_tuples))
labels=Language.EntityTransitionSystem.get_labels(gold_tuples),
beam_width=1)
if n_sents > 0:
gold_tuples = gold_tuples[:n_sents]
@ -188,8 +189,7 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
else:
nlp.tagger(tokens)
gold = GoldParse(tokens, annot_tuples, make_projective=True)
if gold.is_projective:
loss += nlp.parser.train(tokens, gold)
loss += nlp.parser.train(tokens, gold)
nlp.entity.train(tokens, gold)
nlp.tagger.train(tokens, gold.tags)

View File

@ -109,7 +109,7 @@ cdef class Parser:
tokens.set_parse(state.sent)
cdef int _beam_parse(self, Tokens tokens) except -1:
cdef Beam beam = Beam(self.model.n_classes, self.cfg.beam_width)
cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width)
beam.initialize(_init_state, tokens.length, tokens.data)
while not beam.is_done:
self._advance_beam(beam, None, False)
@ -141,9 +141,9 @@ cdef class Parser:
return loss
def _beam_train(self, Tokens tokens, GoldParse gold_parse):
cdef Beam pred = Beam(self.model.n_classes, self.cfg.beam_width)
cdef Beam pred = Beam(self.moves.n_moves, self.cfg.beam_width)
pred.initialize(_init_state, tokens.length, tokens.data)
cdef Beam gold = Beam(self.model.n_classes, self.cfg.beam_width)
cdef Beam gold = Beam(self.moves.n_moves, self.cfg.beam_width)
gold.initialize(_init_state, tokens.length, tokens.data)
violn = MaxViolation()
@ -170,18 +170,18 @@ cdef class Parser:
scores = self.model.score(context)
validities = self.moves.get_valid(state)
if gold is None:
for j in range(self.model.n_clases):
beam.set_cell(i, j, scores[j], 0, validities[j])
for j in range(self.moves.n_moves):
beam.set_cell(i, j, scores[j], validities[j], 0)
elif not follow_gold:
for j in range(self.model.n_classes):
for j in range(self.moves.n_moves):
move = &self.moves.c[j]
cost = move.get_cost(move, state, gold)
beam.set_cell(i, j, scores[j], cost, validities[j])
beam.set_cell(i, j, scores[j], validities[j], cost)
else:
for j in range(self.model.n_classes):
for j in range(self.moves.n_moves):
move = &self.moves.c[j]
cost = move.get_cost(move, state, gold)
beam.set_cell(i, j, scores[j], cost, cost == 0)
beam.set_cell(i, j, scores[j], cost == 0, cost)
beam.advance(_transition_state, <void*>self.moves.c)
beam.check_done(_check_final_state, NULL)