mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 02:36:32 +03:00
Save models properly in conll_train.py
This commit is contained in:
parent
5869f05bd6
commit
e38632003d
|
@ -171,8 +171,7 @@ def train(Language, gold_tuples, model_dir, dev_loc, n_iter=15, feat_set=u'basic
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("Saving model...")
|
print("Saving model...")
|
||||||
break
|
break
|
||||||
#nlp.end_training(model_dir)
|
nlp.end_training(model_dir)
|
||||||
nlp.parser.model.end_training()
|
|
||||||
print("Saved. Evaluating...")
|
print("Saved. Evaluating...")
|
||||||
return nlp
|
return nlp
|
||||||
|
|
||||||
|
@ -215,7 +214,7 @@ def main(train_loc, dev_loc, model_dir, n_iter=15, neural=False, batch_norm=Fals
|
||||||
learn_rate=0.001, update_step='sgd_cm'):
|
learn_rate=0.001, update_step='sgd_cm'):
|
||||||
with io.open(train_loc, 'r', encoding='utf8') as file_:
|
with io.open(train_loc, 'r', encoding='utf8') as file_:
|
||||||
train_sents = list(read_conll(file_))
|
train_sents = list(read_conll(file_))
|
||||||
# preprocess training data here before ArcEager.get_labels() is called
|
# Preprocess training data here before ArcEager.get_labels() is called
|
||||||
train_sents = PseudoProjectivity.preprocess_training_data(train_sents)
|
train_sents = PseudoProjectivity.preprocess_training_data(train_sents)
|
||||||
|
|
||||||
nlp = train(English, train_sents, model_dir, dev_loc, n_iter=n_iter,
|
nlp = train(English, train_sents, model_dir, dev_loc, n_iter=n_iter,
|
||||||
|
@ -225,11 +224,6 @@ def main(train_loc, dev_loc, model_dir, n_iter=15, neural=False, batch_norm=Fals
|
||||||
update_step=update_step)
|
update_step=update_step)
|
||||||
|
|
||||||
scorer = score_file(nlp, dev_loc)
|
scorer = score_file(nlp, dev_loc)
|
||||||
#scorer = Scorer()
|
|
||||||
#with io.open(dev_loc, 'r', encoding='utf8') as file_:
|
|
||||||
# for _, sents in read_conll(file_):
|
|
||||||
# for annot_tuples, _ in sents:
|
|
||||||
# score_model(scorer, nlp, None, annot_tuples)
|
|
||||||
print('TOK', scorer.token_acc)
|
print('TOK', scorer.token_acc)
|
||||||
print('POS', scorer.tags_acc)
|
print('POS', scorer.tags_acc)
|
||||||
print('UAS', scorer.uas)
|
print('UAS', scorer.uas)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user