mirror of https://github.com/explosion/spaCy.git
Refactor train script

commit 1d73dec8b1 (parent ffda38356a)
@@ -8,6 +8,7 @@ import cytoolz
 from pathlib import Path
 import dill
 import tqdm
+from thinc.neural._classes.model import Model
 from thinc.neural.optimizers import linear_decay
 from timeit import default_timer as timer
 
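The added import brings in thinc's Model class: its class-level ops attribute points at the active backend (NumpyOps on CPU, CupyOps on GPU), and ops.asarray converts an array into that backend's array type. A minimal sketch of the call _resume_model makes with the vocab vectors further down; the array shape here is illustrative, not from the commit:

import numpy
from thinc.neural._classes.model import Model

# Model.ops.asarray copies the array into the active backend's array type:
# effectively a no-op on CPU, a host-to-device transfer under a GPU backend.
vectors = numpy.zeros((1000, 300), dtype='float32')  # illustrative shape
vectors = Model.ops.asarray(vectors)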
@@ -69,18 +70,20 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
     batch_sizes = util.compounding(util.env_opt('batch_from', 1),
                                    util.env_opt('batch_to', 64),
                                    util.env_opt('batch_compound', 1.001))
+    corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
+    n_train_words = corpus.count_train()
 
     if not resume:
         lang_class = util.get_lang_class(lang)
         nlp = lang_class(pipeline=pipeline)
+        optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
     else:
         print("Load resume")
-        nlp = _resume_model(lang, pipeline)
+        util.use_gpu(use_gpu)
+        nlp = _resume_model(lang, pipeline, corpus)
+        optimizer = nlp.resume_training(device=use_gpu)
         lang_class = nlp.__class__
 
-    corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
-    n_train_words = corpus.count_train()
-    optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
     nlp._optimizer = None
 
     print("Itn.\tLoss\tUAS\tNER P.\tNER R.\tNER F.\tTag %\tToken %")
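For context on the unchanged lines above: util.compounding yields an infinite batch-size schedule that starts at batch_from, multiplies by batch_compound on every draw, and is clipped at batch_to, while util.env_opt lets each default be overridden through a SPACY_-prefixed environment variable (e.g. SPACY_BATCH_TO). A short sketch with the defaults used here:

from spacy import util

# Infinite generator: 1.0, 1.001, 1.002001, ... capped at 64.
batch_sizes = util.compounding(1., 64., 1.001)
print(next(batch_sizes))  # 1.0
print(next(batch_sizes))  # 1.001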
@@ -101,11 +104,11 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
             util.set_env_log(False)
             epoch_model_path = output_path / ('model%d' % i)
             nlp.to_disk(epoch_model_path)
-            nlp_loaded = lang_class(pipeline=pipeline)
-            nlp_loaded = nlp_loaded.from_disk(epoch_model_path)
-            scorer = nlp_loaded.evaluate(
+            #nlp_loaded = lang_class(pipeline=pipeline)
+            #nlp_loaded = nlp_loaded.from_disk(epoch_model_path)
+            scorer = nlp.evaluate(
                 corpus.dev_docs(
-                    nlp_loaded,
+                    nlp,
                     gold_preproc=gold_preproc))
             acc_loc =(output_path / ('model%d' % i) / 'accuracy.json')
             with acc_loc.open('w') as file_:
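This hunk scores the in-memory nlp object directly and comments out the per-epoch disk round-trip. If the serialization check is wanted again, restoring the commented-out lines would look like the sketch below, using the same names as the surrounding code:

# Reload the epoch model from disk before scoring, as the commented-out
# lines did; slower per epoch, but it also exercises to_disk/from_disk.
nlp_loaded = lang_class(pipeline=pipeline)
nlp_loaded = nlp_loaded.from_disk(epoch_model_path)
scorer = nlp_loaded.evaluate(
    corpus.dev_docs(nlp_loaded, gold_preproc=gold_preproc))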
@@ -114,19 +117,30 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
             print_progress(i, losses, scorer.scores)
     finally:
         print("Saving model...")
-        with (output_path / 'model-final.pickle').open('wb') as file_:
-            with nlp.use_params(optimizer.averages):
-                dill.dump(nlp, file_, -1)
+        try:
+            with (output_path / 'model-final.pickle').open('wb') as file_:
+                with nlp.use_params(optimizer.averages):
+                    dill.dump(nlp, file_, -1)
+        except:
+            pass
 
 
-def _resume_model(lang, pipeline):
+def _resume_model(lang, pipeline, corpus):
     nlp = util.load_model(lang)
     pipes = {getattr(pipe, 'name', None) for pipe in nlp.pipeline}
     for name in pipeline:
         if name not in pipes:
             factory = nlp.Defaults.factories[name]
-            nlp.pipeline.extend(factory(nlp))
+            for pipe in factory(nlp):
+                if hasattr(pipe, 'begin_training'):
+                    pipe.begin_training(corpus.train_tuples,
+                                        pipeline=nlp.pipeline)
+                nlp.pipeline.append(pipe)
     nlp.meta['pipeline'] = pipeline
+    if nlp.vocab.vectors.data.shape[1] >= 1:
+        nlp.vocab.vectors.data = Model.ops.asarray(
+            nlp.vocab.vectors.data)
     return nlp
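Taken together, resuming now threads the corpus into _resume_model so that pipes created from the factories can call begin_training before optimisation continues via nlp.resume_training. A hypothetical call sketch; the 'en' model name and the pipe names are assumptions for illustration, not part of this commit:

# Hypothetical usage: resume an installed 'en' model, adding any missing
# pipes (initialised from the corpus) before continuing training.
corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
nlp = _resume_model('en', ['tagger', 'parser', 'entity'], corpus)
optimizer = nlp.resume_training(device=use_gpu)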