Patch spacy.train for new pipeline management

This commit is contained in:
Matthew Honnibal 2017-10-09 23:41:16 -05:00
parent a635240398
commit 97c9b5db8b

View File

@@ -88,9 +88,11 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=10, n_sents=0,
n_train_words = corpus.count_train() n_train_words = corpus.count_train()
lang_class = util.get_lang_class(lang) lang_class = util.get_lang_class(lang)
nlp = lang_class(pipeline=pipeline) nlp = lang_class()
if vectors: if vectors:
util.load_model(vectors, vocab=nlp.vocab) util.load_model(vectors, vocab=nlp.vocab)
for name in pipeline:
nlp.add_pipe(nlp.create_pipe(name), name=name)
optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu) optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
nlp._optimizer = None nlp._optimizer = None
@@ -113,6 +115,8 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=10, n_sents=0,
epoch_model_path = output_path / ('model%d' % i) epoch_model_path = output_path / ('model%d' % i)
nlp.to_disk(epoch_model_path) nlp.to_disk(epoch_model_path)
nlp_loaded = lang_class(pipeline=pipeline) nlp_loaded = lang_class(pipeline=pipeline)
for name in pipeline:
nlp_loaded.add_pipe(nlp.create_pipe(name), name=name)
nlp_loaded = nlp_loaded.from_disk(epoch_model_path) nlp_loaded = nlp_loaded.from_disk(epoch_model_path)
dev_docs = list(corpus.dev_docs( dev_docs = list(corpus.dev_docs(
nlp_loaded, nlp_loaded,
@@ -128,6 +132,9 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=10, n_sents=0,
gpu_wps = nwords/(end_time-start_time) gpu_wps = nwords/(end_time-start_time)
with Model.use_device('cpu'): with Model.use_device('cpu'):
nlp_loaded = lang_class(pipeline=pipeline) nlp_loaded = lang_class(pipeline=pipeline)
for name in pipeline:
nlp_loaded.add_pipe(nlp.create_pipe(name), name=name)
nlp_loaded = nlp_loaded.from_disk(epoch_model_path) nlp_loaded = nlp_loaded.from_disk(epoch_model_path)
dev_docs = list(corpus.dev_docs( dev_docs = list(corpus.dev_docs(
nlp_loaded, gold_preproc=gold_preproc)) nlp_loaded, gold_preproc=gold_preproc))