diff --git a/bin/parser/train.py b/bin/parser/train.py index 4c8611f2c..2184ffbe6 100755 --- a/bin/parser/train.py +++ b/bin/parser/train.py @@ -187,7 +187,7 @@ def get_labels(sents): def train(Language, train_loc, model_dir, n_iter=15, feat_set=u'basic', seed=0, - gold_preproc=False, force_gold=False, n_sents=0): + gold_preproc=False, n_sents=0): dep_model_dir = path.join(model_dir, 'deps') pos_model_dir = path.join(model_dir, 'pos') ner_model_dir = path.join(model_dir, 'ner') @@ -230,8 +230,8 @@ def train(Language, train_loc, model_dir, n_iter=15, feat_set=u'basic', seed=0, for tokens in sents: gold = GoldParse(tokens, annot_tuples) nlp.tagger(tokens) - nlp.parser.train(tokens, gold, force_gold=force_gold) - nlp.entity.train(tokens, gold, force_gold=force_gold) + nlp.parser.train(tokens, gold) + nlp.entity.train(tokens, gold) nlp.tagger.train(tokens, gold.tags) print '%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.ents_f, scorer.tags_acc) @@ -280,7 +280,7 @@ def write_parses(Language, dev_loc, model_dir, out_loc): def main(train_loc, dev_loc, model_dir, n_sents=0, out_loc="", verbose=False, debug=False): train(English, train_loc, model_dir, feat_set='basic' if not debug else 'debug', - gold_preproc=False, force_gold=False, n_sents=n_sents) + gold_preproc=False, n_sents=n_sents) if out_loc: write_parses(English, dev_loc, model_dir, out_loc) scorer = evaluate(English, dev_loc, model_dir, gold_preproc=False, verbose=verbose) diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index cc2c30143..9b71e78ed 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -93,7 +93,7 @@ cdef class GreedyParser: tokens.set_parse(state.sent) return 0 - def train(self, Tokens tokens, GoldParse gold, force_gold=False): + def train(self, Tokens tokens, GoldParse gold): self.moves.preprocess_gold(gold) cdef Pool mem = Pool() cdef State* state = new_state(mem, tokens.data, tokens.length) @@ -109,12 +109,11 @@ cdef class GreedyParser: while not is_final(state): fill_context(context, state) scores = self.model.score(context) + guess = self.moves.best_valid(scores, state) best = self.moves.best_gold(scores, state, gold) + cost = guess.get_cost(&guess, state, gold) - self.model.update(context, guess.clas, best.clas, cost) - if force_gold: - best.do(&best, state) - else: - guess.do(&guess, state) + + guess.do(&guess, state)