diff --git a/spacy/cli/train.py b/spacy/cli/train.py
index 7ad94ce9c..8a3446cfe 100644
--- a/spacy/cli/train.py
+++ b/spacy/cli/train.py
@@ -53,7 +53,6 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
     if dev_path and not dev_path.exists():
         prints(dev_path, title="Development data not found", exits=1)
 
-    lang_class = util.get_lang_class(lang)
     pipeline = ['token_vectors', 'tags', 'dependencies', 'entities']
     if no_tagger and 'tags' in pipeline:
         pipeline.remove('tags')
@@ -71,22 +70,22 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
                                    util.env_opt('batch_to', 64),
                                    util.env_opt('batch_compound', 1.001))
 
-    if resume:
-        prints(output_path / 'model9.pickle', title="Resuming training")
-        nlp = dill.load((output_path / 'model9.pickle').open('rb'))
-    else:
+    if not resume:
+        lang_class = util.get_lang_class(lang)
         nlp = lang_class(pipeline=pipeline)
+    else:
+        print("Load resume")
+        nlp = _resume_model(lang, pipeline)
+        lang_class = nlp.__class__
+
     corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
     n_train_words = corpus.count_train()
-    optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
     nlp._optimizer = None
 
     print("Itn.\tLoss\tUAS\tNER P.\tNER R.\tNER F.\tTag %\tToken %")
     try:
         for i in range(n_iter):
-            if resume:
-                i += 20
             with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
                 train_docs = corpus.train_docs(nlp, projectivize=True,
                                                noise_level=0.0,
                                                gold_preproc=gold_preproc, max_length=0)
@@ -120,6 +119,17 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
             dill.dump(nlp, file_, -1)
 
 
+def _resume_model(lang, pipeline):
+    nlp = util.load_model(lang)
+    pipes = {getattr(pipe, 'name', None) for pipe in nlp.pipeline}
+    for name in pipeline:
+        if name not in pipes:
+            factory = nlp.Defaults.factories[name]
+            nlp.pipeline.extend(factory(nlp))
+    nlp.meta['pipeline'] = pipeline
+    return nlp
+
+
 def _render_parses(i, to_render):
     to_render[0].user_data['title'] = "Batch %d" % i
     with Path('/tmp/entities.html').open('w') as file_:
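
Note: with this patch the resume path no longer unpickles a 'model9.pickle'
snapshot. It reloads the saved model via util.load_model() and backfills any
pipeline components the saved model lacks (e.g. a model trained with
no_tagger/no_entities set), then records the full pipeline in nlp.meta. A
minimal sketch of exercising the new helper directly, assuming an 'en' model
package or shortcut link that util.load_model() can resolve, and that this
private helper keeps the signature shown above:

    from spacy.cli.train import _resume_model

    # Reload the saved 'en' model; any of the four components it does not
    # already have are created from nlp.Defaults.factories and appended,
    # and nlp.meta['pipeline'] is rewritten to the full component list.
    nlp = _resume_model('en', ['token_vectors', 'tags', 'dependencies',
                               'entities'])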