mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-12 04:38:28 +03:00
* Update conll_train script
This commit is contained in:
parent
fab538672e
commit
77f2b218f9
|
@@ -87,15 +87,15 @@ def _parse_line(line):
|
|||
|
||||
|
||||
def score_model(nlp, gold_tuples, verbose=False):
|
||||
- scorer = Scorer()
|
||||
+ correct = 0.0
|
||||
+ total = 0.0
|
||||
for words, gold_tags in gold_tuples:
|
||||
tokens = nlp.tokenizer.tokens_from_list(words)
|
||||
nlp.tagger(tokens)
|
||||
for token, gold in zip(tokens, gold_tags):
|
||||
- scorer.tags.tp += token.tag_ == gold
|
||||
- scorer.tags.fp += token.tag_ != gold
|
||||
- scorer.tags.fn += token.tag_ != gold
|
||||
- return scorer.tags_acc
|
||||
+ correct += token.tag_ == gold
|
||||
+ total += 1
|
||||
+ return (correct / total) * 100
|
||||
|
||||
|
||||
def train(Language, train_sents, dev_sents, model_dir, n_iter=15, seed=0,
|
||||
|
@@ -116,8 +116,6 @@ def train(Language, train_sents, dev_sents, model_dir, n_iter=15, seed=0,
|
|||
random.shuffle(train_sents)
|
||||
heldout_sents = train_sents[:int(nr_train * 0.1)]
|
||||
train_sents = train_sents[len(heldout_sents):]
|
||||
- #train_sents = train_sents[:500]
|
||||
- #assert len(heldout_sents) < len(train_sents)
|
||||
prev_score = 0.0
|
||||
variance = 0.001
|
||||
last_good_learn_rate = nlp.tagger.model.eta
|
||||
|
@@ -130,15 +128,15 @@ def train(Language, train_sents, dev_sents, model_dir, n_iter=15, seed=0,
|
|||
acc += nlp.tagger.train(tokens, gold_tags)
|
||||
total += len(tokens)
|
||||
n += 1
|
||||
- if n and n % 10000 == 0:
|
||||
+ if n and n % 20000 == 0:
|
||||
dev_score = score_model(nlp, heldout_sents)
|
||||
eval_score = score_model(nlp, dev_sents)
|
||||
- if dev_score > prev_score:
|
||||
+ if dev_score >= prev_score:
|
||||
nlp.tagger.model.keep_update()
|
||||
prev_score = dev_score
|
||||
variance = 0.001
|
||||
last_good_learn_rate = nlp.tagger.model.eta
|
||||
- nlp.tagger.model.eta *= 1.05
|
||||
+ nlp.tagger.model.eta *= 1.01
|
||||
print('%d:\t%.3f\t%.3f\t%.3f\t%.4f' % (n, acc/total, dev_score, eval_score, nlp.tagger.model.eta))
|
||||
else:
|
||||
nlp.tagger.model.backtrack()
|
||||
|
|
Loading…
Reference in New Issue
Block a user