mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 17:54:39 +03:00
* Fix bugs in new greedy/beam parser
This commit is contained in:
parent
66dfa95847
commit
e822df0867
|
@ -140,7 +140,7 @@ def _tag_partition(nlp, docs, gold_preproc=False):
|
||||||
|
|
||||||
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
||||||
seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
|
seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
|
||||||
train_tags=None):
|
train_tags=None, beam_width=1):
|
||||||
dep_model_dir = path.join(model_dir, 'deps')
|
dep_model_dir = path.join(model_dir, 'deps')
|
||||||
pos_model_dir = path.join(model_dir, 'pos')
|
pos_model_dir = path.join(model_dir, 'pos')
|
||||||
ner_model_dir = path.join(model_dir, 'ner')
|
ner_model_dir = path.join(model_dir, 'ner')
|
||||||
|
@ -158,9 +158,10 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
||||||
|
|
||||||
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
||||||
labels=Language.ParserTransitionSystem.get_labels(gold_tuples),
|
labels=Language.ParserTransitionSystem.get_labels(gold_tuples),
|
||||||
beam_width=16)
|
beam_width=beam_width)
|
||||||
Config.write(ner_model_dir, 'config', features='ner', seed=seed,
|
Config.write(ner_model_dir, 'config', features='ner', seed=seed,
|
||||||
labels=Language.EntityTransitionSystem.get_labels(gold_tuples))
|
labels=Language.EntityTransitionSystem.get_labels(gold_tuples),
|
||||||
|
beam_width=1)
|
||||||
|
|
||||||
if n_sents > 0:
|
if n_sents > 0:
|
||||||
gold_tuples = gold_tuples[:n_sents]
|
gold_tuples = gold_tuples[:n_sents]
|
||||||
|
@ -188,8 +189,7 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
||||||
else:
|
else:
|
||||||
nlp.tagger(tokens)
|
nlp.tagger(tokens)
|
||||||
gold = GoldParse(tokens, annot_tuples, make_projective=True)
|
gold = GoldParse(tokens, annot_tuples, make_projective=True)
|
||||||
if gold.is_projective:
|
loss += nlp.parser.train(tokens, gold)
|
||||||
loss += nlp.parser.train(tokens, gold)
|
|
||||||
|
|
||||||
nlp.entity.train(tokens, gold)
|
nlp.entity.train(tokens, gold)
|
||||||
nlp.tagger.train(tokens, gold.tags)
|
nlp.tagger.train(tokens, gold.tags)
|
||||||
|
|
|
@ -109,7 +109,7 @@ cdef class Parser:
|
||||||
tokens.set_parse(state.sent)
|
tokens.set_parse(state.sent)
|
||||||
|
|
||||||
cdef int _beam_parse(self, Tokens tokens) except -1:
|
cdef int _beam_parse(self, Tokens tokens) except -1:
|
||||||
cdef Beam beam = Beam(self.model.n_classes, self.cfg.beam_width)
|
cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width)
|
||||||
beam.initialize(_init_state, tokens.length, tokens.data)
|
beam.initialize(_init_state, tokens.length, tokens.data)
|
||||||
while not beam.is_done:
|
while not beam.is_done:
|
||||||
self._advance_beam(beam, None, False)
|
self._advance_beam(beam, None, False)
|
||||||
|
@ -141,9 +141,9 @@ cdef class Parser:
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def _beam_train(self, Tokens tokens, GoldParse gold_parse):
|
def _beam_train(self, Tokens tokens, GoldParse gold_parse):
|
||||||
cdef Beam pred = Beam(self.model.n_classes, self.cfg.beam_width)
|
cdef Beam pred = Beam(self.moves.n_moves, self.cfg.beam_width)
|
||||||
pred.initialize(_init_state, tokens.length, tokens.data)
|
pred.initialize(_init_state, tokens.length, tokens.data)
|
||||||
cdef Beam gold = Beam(self.model.n_classes, self.cfg.beam_width)
|
cdef Beam gold = Beam(self.moves.n_moves, self.cfg.beam_width)
|
||||||
gold.initialize(_init_state, tokens.length, tokens.data)
|
gold.initialize(_init_state, tokens.length, tokens.data)
|
||||||
|
|
||||||
violn = MaxViolation()
|
violn = MaxViolation()
|
||||||
|
@ -170,18 +170,18 @@ cdef class Parser:
|
||||||
scores = self.model.score(context)
|
scores = self.model.score(context)
|
||||||
validities = self.moves.get_valid(state)
|
validities = self.moves.get_valid(state)
|
||||||
if gold is None:
|
if gold is None:
|
||||||
for j in range(self.model.n_clases):
|
for j in range(self.moves.n_moves):
|
||||||
beam.set_cell(i, j, scores[j], 0, validities[j])
|
beam.set_cell(i, j, scores[j], validities[j], 0)
|
||||||
elif not follow_gold:
|
elif not follow_gold:
|
||||||
for j in range(self.model.n_classes):
|
for j in range(self.moves.n_moves):
|
||||||
move = &self.moves.c[j]
|
move = &self.moves.c[j]
|
||||||
cost = move.get_cost(move, state, gold)
|
cost = move.get_cost(move, state, gold)
|
||||||
beam.set_cell(i, j, scores[j], cost, validities[j])
|
beam.set_cell(i, j, scores[j], validities[j], cost)
|
||||||
else:
|
else:
|
||||||
for j in range(self.model.n_classes):
|
for j in range(self.moves.n_moves):
|
||||||
move = &self.moves.c[j]
|
move = &self.moves.c[j]
|
||||||
cost = move.get_cost(move, state, gold)
|
cost = move.get_cost(move, state, gold)
|
||||||
beam.set_cell(i, j, scores[j], cost, cost == 0)
|
beam.set_cell(i, j, scores[j], cost == 0, cost)
|
||||||
beam.advance(_transition_state, <void*>self.moves.c)
|
beam.advance(_transition_state, <void*>self.moves.c)
|
||||||
beam.check_done(_check_final_state, NULL)
|
beam.check_done(_check_final_state, NULL)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user