Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-13 18:56:36 +03:00)
* Fix bugs in new greedy/beam parser
parent 66dfa95847
commit e822df0867
@@ -140,7 +140,7 @@ def _tag_partition(nlp, docs, gold_preproc=False):
 def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
           seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
-          train_tags=None):
+          train_tags=None, beam_width=1):
     dep_model_dir = path.join(model_dir, 'deps')
     pos_model_dir = path.join(model_dir, 'pos')
     ner_model_dir = path.join(model_dir, 'ner')
@@ -158,9 +158,10 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
     Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
                  labels=Language.ParserTransitionSystem.get_labels(gold_tuples),
-                 beam_width=16)
+                 beam_width=beam_width)
     Config.write(ner_model_dir, 'config', features='ner', seed=seed,
-                 labels=Language.EntityTransitionSystem.get_labels(gold_tuples))
+                 labels=Language.EntityTransitionSystem.get_labels(gold_tuples),
+                 beam_width=1)
 
     if n_sents > 0:
         gold_tuples = gold_tuples[:n_sents]
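Note: train() now takes the beam width as an argument instead of hard-coding beam_width=16 into the dependency parser's config, while the NER config is pinned to a width of 1. A minimal usage sketch, assuming the spacy.en.English Language subclass and treating gold_tuples, model_dir and the width of 8 as placeholders:

    # Hypothetical call -- the names here are placeholders, not part of this commit.
    from spacy.en import English
    train(English, gold_tuples, model_dir,
          n_iter=15, feat_set=u'basic',
          beam_width=8)   # beam_width=1 presumably keeps the greedy parser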
@@ -188,8 +189,7 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
             else:
                 nlp.tagger(tokens)
             gold = GoldParse(tokens, annot_tuples, make_projective=True)
-            if gold.is_projective:
-                loss += nlp.parser.train(tokens, gold)
+            loss += nlp.parser.train(tokens, gold)
 
             nlp.entity.train(tokens, gold)
             nlp.tagger.train(tokens, gold.tags)
@@ -109,7 +109,7 @@ cdef class Parser:
         tokens.set_parse(state.sent)
 
     cdef int _beam_parse(self, Tokens tokens) except -1:
-        cdef Beam beam = Beam(self.model.n_classes, self.cfg.beam_width)
+        cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width)
         beam.initialize(_init_state, tokens.length, tokens.data)
         while not beam.is_done:
            self._advance_beam(beam, None, False)
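Note: the fix above sizes the beam by the transition system's move count (self.moves.n_moves) rather than the model's class count. A toy, self-contained sketch of transition-based beam search, only to make the two constructor arguments concrete; every name below is illustrative, not spaCy's or thinc's API:

    def beam_search(init_state, n_moves, width, score_move, apply_move, is_final):
        # Each candidate is (cumulative score, parser state).
        beam = [(0.0, init_state)]
        while not all(is_final(state) for _, state in beam):
            candidates = []
            for total, state in beam:
                if is_final(state):
                    candidates.append((total, state))
                    continue
                for move in range(n_moves):   # one cell per possible transition
                    candidates.append((total + score_move(state, move),
                                       apply_move(state, move)))
            candidates.sort(key=lambda c: c[0], reverse=True)
            beam = candidates[:width]         # keep only the `width` best states
        return beam[0][1]                     # highest-scoring finished state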
@@ -141,9 +141,9 @@ cdef class Parser:
         return loss
 
     def _beam_train(self, Tokens tokens, GoldParse gold_parse):
-        cdef Beam pred = Beam(self.model.n_classes, self.cfg.beam_width)
+        cdef Beam pred = Beam(self.moves.n_moves, self.cfg.beam_width)
         pred.initialize(_init_state, tokens.length, tokens.data)
-        cdef Beam gold = Beam(self.model.n_classes, self.cfg.beam_width)
+        cdef Beam gold = Beam(self.moves.n_moves, self.cfg.beam_width)
         gold.initialize(_init_state, tokens.length, tokens.data)
 
         violn = MaxViolation()
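Note: _beam_train drives two beams, a prediction beam (pred) and a gold-constrained beam (gold), and MaxViolation decides where to apply the perceptron update. A generic sketch of the max-violation idea (Huang et al., 2012); this is illustrative Python under assumed names, not thinc's MaxViolation API:

    def max_violation_update(pred_scores, gold_scores, pred_hists, gold_hists, update):
        # pred_scores[t]: best score in the prediction beam after step t
        # gold_scores[t]: best score among gold-following candidates after step t
        deltas = [p - g for p, g in zip(pred_scores, gold_scores)]
        t = max(range(len(deltas)), key=deltas.__getitem__)   # most violated prefix
        if deltas[t] > 0:
            # reward the gold action history up to step t, penalise the predicted one
            update(good=gold_hists[t], bad=pred_hists[t])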
@@ -170,18 +170,18 @@ cdef class Parser:
             scores = self.model.score(context)
             validities = self.moves.get_valid(state)
             if gold is None:
-                for j in range(self.model.n_clases):
-                    beam.set_cell(i, j, scores[j], 0, validities[j])
+                for j in range(self.moves.n_moves):
+                    beam.set_cell(i, j, scores[j], validities[j], 0)
             elif not follow_gold:
-                for j in range(self.model.n_classes):
+                for j in range(self.moves.n_moves):
                     move = &self.moves.c[j]
                     cost = move.get_cost(move, state, gold)
-                    beam.set_cell(i, j, scores[j], cost, validities[j])
+                    beam.set_cell(i, j, scores[j], validities[j], cost)
             else:
-                for j in range(self.model.n_classes):
+                for j in range(self.moves.n_moves):
                     move = &self.moves.c[j]
                     cost = move.get_cost(move, state, gold)
-                    beam.set_cell(i, j, scores[j], cost, cost == 0)
+                    beam.set_cell(i, j, scores[j], cost == 0, cost)
         beam.advance(_transition_state, <void*>self.moves.c)
         beam.check_done(_check_final_state, NULL)
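Note: the other fix in this hunk is the argument order passed to Beam.set_cell. Judging from the corrected calls, the expected order is (i, j, score, is_valid, cost), and the old code was putting the cost in the is_valid slot. The three corrected call sites, with interpretive comments (semantics inferred from the diff, not from thinc's documentation):

    beam.set_cell(i, j, scores[j], validities[j], 0)      # parsing: any valid move, no gold cost tracked
    beam.set_cell(i, j, scores[j], validities[j], cost)   # training, prediction beam: valid moves, cost recorded
    beam.set_cell(i, j, scores[j], cost == 0, cost)       # training, gold beam: only zero-cost moves are valid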