From 77f2b218f9fb5fe2460154f4e3747ebceb8be703 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 24 Feb 2016 18:19:38 +0100 Subject: [PATCH] * Update conll_train script --- bin/tagger/conll_train.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/bin/tagger/conll_train.py b/bin/tagger/conll_train.py index 447e3880b..97b3c3c99 100755 --- a/bin/tagger/conll_train.py +++ b/bin/tagger/conll_train.py @@ -87,15 +87,15 @@ def _parse_line(line): def score_model(nlp, gold_tuples, verbose=False): - scorer = Scorer() + correct = 0.0 + total = 0.0 for words, gold_tags in gold_tuples: tokens = nlp.tokenizer.tokens_from_list(words) nlp.tagger(tokens) for token, gold in zip(tokens, gold_tags): - scorer.tags.tp += token.tag_ == gold - scorer.tags.fp += token.tag_ != gold - scorer.tags.fn += token.tag_ != gold - return scorer.tags_acc + correct += token.tag_ == gold + total += 1 + return (correct / total) * 100 def train(Language, train_sents, dev_sents, model_dir, n_iter=15, seed=0, @@ -116,8 +116,6 @@ def train(Language, train_sents, dev_sents, model_dir, n_iter=15, seed=0, random.shuffle(train_sents) heldout_sents = train_sents[:int(nr_train * 0.1)] train_sents = train_sents[len(heldout_sents):] - #train_sents = train_sents[:500] - #assert len(heldout_sents) < len(train_sents) prev_score = 0.0 variance = 0.001 last_good_learn_rate = nlp.tagger.model.eta @@ -130,15 +128,15 @@ def train(Language, train_sents, dev_sents, model_dir, n_iter=15, seed=0, acc += nlp.tagger.train(tokens, gold_tags) total += len(tokens) n += 1 - if n and n % 10000 == 0: + if n and n % 20000 == 0: dev_score = score_model(nlp, heldout_sents) eval_score = score_model(nlp, dev_sents) - if dev_score > prev_score: + if dev_score >= prev_score: nlp.tagger.model.keep_update() prev_score = dev_score variance = 0.001 last_good_learn_rate = nlp.tagger.model.eta - nlp.tagger.model.eta *= 1.05 + nlp.tagger.model.eta *= 1.01 print('%d:\t%.3f\t%.3f\t%.3f\t%.4f' % (n, acc/total, dev_score, eval_score, nlp.tagger.model.eta)) else: nlp.tagger.model.backtrack()