mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Save models properly in conll_train.py
This commit is contained in:
parent
5869f05bd6
commit
e38632003d
|
@ -171,8 +171,7 @@ def train(Language, gold_tuples, model_dir, dev_loc, n_iter=15, feat_set=u'basic
|
|||
except KeyboardInterrupt:
|
||||
print("Saving model...")
|
||||
break
|
||||
#nlp.end_training(model_dir)
|
||||
nlp.parser.model.end_training()
|
||||
nlp.end_training(model_dir)
|
||||
print("Saved. Evaluating...")
|
||||
return nlp
|
||||
|
||||
|
@ -215,7 +214,7 @@ def main(train_loc, dev_loc, model_dir, n_iter=15, neural=False, batch_norm=Fals
|
|||
learn_rate=0.001, update_step='sgd_cm'):
|
||||
with io.open(train_loc, 'r', encoding='utf8') as file_:
|
||||
train_sents = list(read_conll(file_))
|
||||
# preprocess training data here before ArcEager.get_labels() is called
|
||||
# Preprocess training data here before ArcEager.get_labels() is called
|
||||
train_sents = PseudoProjectivity.preprocess_training_data(train_sents)
|
||||
|
||||
nlp = train(English, train_sents, model_dir, dev_loc, n_iter=n_iter,
|
||||
|
@ -225,11 +224,6 @@ def main(train_loc, dev_loc, model_dir, n_iter=15, neural=False, batch_norm=Fals
|
|||
update_step=update_step)
|
||||
|
||||
scorer = score_file(nlp, dev_loc)
|
||||
#scorer = Scorer()
|
||||
#with io.open(dev_loc, 'r', encoding='utf8') as file_:
|
||||
# for _, sents in read_conll(file_):
|
||||
# for annot_tuples, _ in sents:
|
||||
# score_model(scorer, nlp, None, annot_tuples)
|
||||
print('TOK', scorer.token_acc)
|
||||
print('POS', scorer.tags_acc)
|
||||
print('UAS', scorer.uas)
|
||||
|
|
Loading…
Reference in New Issue
Block a user