* Increase default number of iterations from 5 to 10

This commit is contained in:
Matthew Honnibal 2014-11-07 04:42:04 +11:00
parent 3cab1d9a29
commit 949a6245f9

View File

@ -10,6 +10,7 @@ import random
import json import json
import cython import cython
from .context cimport fill_slots from .context cimport fill_slots
from .context cimport fill_flat from .context cimport fill_flat
from .context cimport N_FIELDS from .context cimport N_FIELDS
@ -33,7 +34,7 @@ def setup_model_dir(tag_type, tag_names, templates, model_dir):
json.dump(config, file_) json.dump(config, file_)
def train(train_sents, model_dir, nr_iter=5): def train(train_sents, model_dir, nr_iter=10):
tagger = Tagger(model_dir) tagger = Tagger(model_dir)
for _ in range(nr_iter): for _ in range(nr_iter):
n_corr = 0 n_corr = 0
@ -43,15 +44,15 @@ def train(train_sents, model_dir, nr_iter=5):
for i, gold in enumerate(golds): for i, gold in enumerate(golds):
guess = tagger.predict(i, tokens) guess = tagger.predict(i, tokens)
tokens.set_tag(i, tagger.tag_type, guess) tokens.set_tag(i, tagger.tag_type, guess)
tagger.tell_answer(gold)
if gold != NULL_TAG: if gold != NULL_TAG:
tagger.tell_answer(gold)
total += 1 total += 1
n_corr += guess == gold n_corr += guess == gold
#print('%s\t%d\t%d' % (tokens[i].string, guess, gold)) #print('%s\t%d\t%d' % (tokens[i].string, guess, gold))
print('%.4f' % ((n_corr / total) * 100)) print('%.4f' % ((n_corr / total) * 100))
random.shuffle(train_sents) random.shuffle(train_sents)
tagger.model.end_training() tagger.model.end_training()
tagger.model.dump(path.join(model_dir, 'model'), freq_thresh=10) tagger.model.dump(path.join(model_dir, 'model'))
def evaluate(tagger, sents): def evaluate(tagger, sents):