Mirror of https://github.com/explosion/spaCy.git (synced 2024-12-25 09:26:27 +03:00)

Support resuming a model during spacy train

parent c858927271
commit a0c4b33d03

@@ -53,7 +53,6 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
     if dev_path and not dev_path.exists():
         prints(dev_path, title="Development data not found", exits=1)
 
-    lang_class = util.get_lang_class(lang)
 
     pipeline = ['token_vectors', 'tags', 'dependencies', 'entities']
     if no_tagger and 'tags' in pipeline: pipeline.remove('tags')
@@ -71,22 +70,22 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
                                    util.env_opt('batch_to', 64),
                                    util.env_opt('batch_compound', 1.001))
 
-    if resume:
-        prints(output_path / 'model9.pickle', title="Resuming training")
-        nlp = dill.load((output_path / 'model9.pickle').open('rb'))
-    else:
+    if not resume:
+        lang_class = util.get_lang_class(lang)
         nlp = lang_class(pipeline=pipeline)
+    else:
+        print("Load resume")
+        nlp = _resume_model(lang, pipeline)
+        lang_class = nlp.__class__
     corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
     n_train_words = corpus.count_train()
 
     optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
     nlp._optimizer = None
 
     print("Itn.\tLoss\tUAS\tNER P.\tNER R.\tNER F.\tTag %\tToken %")
     try:
         for i in range(n_iter):
-            if resume:
-                i += 20
             with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
                 train_docs = corpus.train_docs(nlp, projectivize=True, noise_level=0.0,
                                                gold_preproc=gold_preproc, max_length=0)
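
The hunk above replaces the pickle-based resume path (dill-loading model9.pickle) with a branch on the resume flag: a fresh run resolves the language class and builds an empty pipeline, while a resumed run reloads an existing model via _resume_model and recovers the class from the loaded object; the hard-coded i += 20 iteration offset is dropped at the same time. A minimal standalone sketch of the new branch, assuming the util.get_lang_class and util.load_model helpers shown in the diff (the component backfill done by _resume_model is simplified away, and build_nlp is an illustrative name, not part of the commit):

    # Sketch only: build_nlp is a hypothetical helper, not in the commit.
    from spacy import util

    def build_nlp(lang, pipeline, resume=False):
        if not resume:
            # Fresh start: look up the language class and construct an
            # empty pipeline with the requested components.
            lang_class = util.get_lang_class(lang)
            nlp = lang_class(pipeline=pipeline)
        else:
            # Resume: reload a previously trained model and take its
            # language class from the loaded object.
            nlp = util.load_model(lang)
            lang_class = nlp.__class__
        return nlp, lang_class
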
@@ -120,6 +119,17 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
             dill.dump(nlp, file_, -1)
 
 
+def _resume_model(lang, pipeline):
+    nlp = util.load_model(lang)
+    pipes = {getattr(pipe, 'name', None) for pipe in nlp.pipeline}
+    for name in pipeline:
+        if name not in pipes:
+            factory = nlp.Defaults.factories[name]
+            nlp.pipeline.extend(factory(nlp))
+    nlp.meta['pipeline'] = pipeline
+    return nlp
+
+
 def _render_parses(i, to_render):
     to_render[0].user_data['title'] = "Batch %d" % i
     with Path('/tmp/entities.html').open('w') as file_:
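
The new _resume_model helper reloads a saved model with util.load_model, collects the names of the components it already has, and instantiates any requested component that is missing from nlp.Defaults.factories before recording the full pipeline in the model's meta. A hedged usage sketch; the model name 'en' is illustrative and not something this commit ships:

    # Hypothetical call: resume from an installed 'en' model, ensuring all
    # four requested components exist on the loaded pipeline.
    pipeline = ['token_vectors', 'tags', 'dependencies', 'entities']
    nlp = _resume_model('en', pipeline)

    # Any component missing from the saved model was created fresh via
    # nlp.Defaults.factories[name] and appended to nlp.pipeline; the meta
    # now records the full requested pipeline.
    assert nlp.meta['pipeline'] == pipeline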