mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Improve train CLI script
This commit is contained in:
parent
d21459f87d
commit
43353b5413
|
@ -28,15 +28,17 @@ from .. import displacy
|
||||||
n_iter=("number of iterations", "option", "n", int),
|
n_iter=("number of iterations", "option", "n", int),
|
||||||
n_sents=("number of sentences", "option", "ns", int),
|
n_sents=("number of sentences", "option", "ns", int),
|
||||||
use_gpu=("Use GPU", "flag", "G", bool),
|
use_gpu=("Use GPU", "flag", "G", bool),
|
||||||
|
resume=("Whether to resume training", "flag", "R", bool),
|
||||||
no_tagger=("Don't train tagger", "flag", "T", bool),
|
no_tagger=("Don't train tagger", "flag", "T", bool),
|
||||||
no_parser=("Don't train parser", "flag", "P", bool),
|
no_parser=("Don't train parser", "flag", "P", bool),
|
||||||
no_entities=("Don't train NER", "flag", "N", bool)
|
no_entities=("Don't train NER", "flag", "N", bool)
|
||||||
)
|
)
|
||||||
def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
||||||
use_gpu=False, no_tagger=False, no_parser=False, no_entities=False):
|
use_gpu=False, resume=False, no_tagger=False, no_parser=False, no_entities=False):
|
||||||
"""
|
"""
|
||||||
Train a model. Expects data in spaCy's JSON format.
|
Train a model. Expects data in spaCy's JSON format.
|
||||||
"""
|
"""
|
||||||
|
util.set_env_log(True)
|
||||||
n_sents = n_sents or None
|
n_sents = n_sents or None
|
||||||
output_path = util.ensure_path(output_dir)
|
output_path = util.ensure_path(output_dir)
|
||||||
train_path = util.ensure_path(train_data)
|
train_path = util.ensure_path(train_data)
|
||||||
|
@ -66,6 +68,10 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
||||||
util.env_opt('batch_to', 64),
|
util.env_opt('batch_to', 64),
|
||||||
util.env_opt('batch_compound', 1.001))
|
util.env_opt('batch_compound', 1.001))
|
||||||
|
|
||||||
|
if resume:
|
||||||
|
prints(output_path / 'model19.pickle', title="Resuming training")
|
||||||
|
nlp = dill.load((output_path / 'model19.pickle').open('rb'))
|
||||||
|
else:
|
||||||
nlp = lang_class(pipeline=pipeline)
|
nlp = lang_class(pipeline=pipeline)
|
||||||
corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
|
corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
|
||||||
n_train_docs = corpus.count_train()
|
n_train_docs = corpus.count_train()
|
||||||
|
@ -75,6 +81,8 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
||||||
print("Itn.\tLoss\tUAS\tNER P.\tNER R.\tNER F.\tTag %\tToken %")
|
print("Itn.\tLoss\tUAS\tNER P.\tNER R.\tNER F.\tTag %\tToken %")
|
||||||
try:
|
try:
|
||||||
for i in range(n_iter):
|
for i in range(n_iter):
|
||||||
|
if resume:
|
||||||
|
i += 20
|
||||||
with tqdm.tqdm(total=corpus.count_train(), leave=False) as pbar:
|
with tqdm.tqdm(total=corpus.count_train(), leave=False) as pbar:
|
||||||
train_docs = corpus.train_docs(nlp, projectivize=True,
|
train_docs = corpus.train_docs(nlp, projectivize=True,
|
||||||
gold_preproc=False, max_length=0)
|
gold_preproc=False, max_length=0)
|
||||||
|
@ -86,14 +94,18 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
||||||
pbar.update(len(docs))
|
pbar.update(len(docs))
|
||||||
|
|
||||||
with nlp.use_params(optimizer.averages):
|
with nlp.use_params(optimizer.averages):
|
||||||
|
util.set_env_log(False)
|
||||||
|
epoch_model_path = output_path / ('model%d' % i)
|
||||||
|
nlp.to_disk(epoch_model_path)
|
||||||
with (output_path / ('model%d.pickle' % i)).open('wb') as file_:
|
with (output_path / ('model%d.pickle' % i)).open('wb') as file_:
|
||||||
dill.dump(nlp, file_, -1)
|
dill.dump(nlp, file_, -1)
|
||||||
with (output_path / ('model%d.bin' % i)).open('wb') as file_:
|
|
||||||
file_.write(nlp.to_bytes())
|
|
||||||
with (output_path / ('model%d.bin' % i)).open('rb') as file_:
|
|
||||||
nlp_loaded = lang_class(pipeline=pipeline)
|
nlp_loaded = lang_class(pipeline=pipeline)
|
||||||
nlp_loaded.from_bytes(file_.read())
|
nlp_loaded = nlp_loaded.from_disk(epoch_model_path)
|
||||||
scorer = nlp_loaded.evaluate(corpus.dev_docs(nlp_loaded, gold_preproc=False))
|
scorer = nlp_loaded.evaluate(
|
||||||
|
corpus.dev_docs(
|
||||||
|
nlp_loaded,
|
||||||
|
gold_preproc=False))
|
||||||
|
util.set_env_log(True)
|
||||||
print_progress(i, losses, scorer.scores)
|
print_progress(i, losses, scorer.scores)
|
||||||
finally:
|
finally:
|
||||||
print("Saving model...")
|
print("Saving model...")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user