Mirror of https://github.com/explosion/spaCy.git
Improve train CLI
Commit: c52fde40f4
Parent: a053b1218e
spacy/cli/train.py:

@@ -18,6 +18,7 @@ from ..gold import GoldCorpus, minibatch
 from ..util import prints
 from .. import util
 from .. import displacy
+from ..compat import json_dumps
 
 
 @plac.annotations(
@@ -44,7 +45,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
     train_path = util.ensure_path(train_data)
     dev_path = util.ensure_path(dev_data)
     if not output_path.exists():
-        prints(output_path, title="Output directory not found", exits=1)
+        output_path.mkdir()
     if not train_path.exists():
         prints(train_path, title="Training data not found", exits=1)
     if dev_path and not dev_path.exists():
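The first functional change: a missing output directory is now created instead of aborting the run. A minimal sketch of the new behaviour, using a hypothetical path (note the commit passes no arguments to mkdir()):

```python
from pathlib import Path

# Hypothetical output directory for illustration.
output_path = Path("model-output")

if not output_path.exists():
    # Old behaviour: prints(..., exits=1) aborted the run here.
    # New behaviour: create the directory and continue training.
    # A bare mkdir() still raises FileNotFoundError if the parent
    # directory is missing, since parents=True is not passed.
    output_path.mkdir()
```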
@@ -74,7 +75,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
     else:
         nlp = lang_class(pipeline=pipeline)
     corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
-    n_train_docs = corpus.count_train()
+    n_train_words = corpus.count_train()
 
     optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
 
@@ -83,7 +84,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
     for i in range(n_iter):
         if resume:
             i += 20
-        with tqdm.tqdm(total=corpus.count_train(), leave=False) as pbar:
+        with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
             train_docs = corpus.train_docs(nlp, projectivize=True,
                                            gold_preproc=False, max_length=0)
             losses = {}
@@ -91,7 +92,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
                 docs, golds = zip(*batch)
                 nlp.update(docs, golds, sgd=optimizer,
                            drop=next(dropout_rates), losses=losses)
-                pbar.update(len(docs))
+                pbar.update(sum(len(doc) for doc in docs))
 
         with nlp.use_params(optimizer.averages):
             util.set_env_log(False)
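The three word-count changes (the n_train_words rename, the tqdm total, and the pbar.update() call) work together: the progress bar is now sized and advanced in tokens rather than documents, so it reflects actual work done per batch. A minimal sketch under that reading, with a fabricated corpus standing in for GoldCorpus:

```python
import tqdm

# Fabricated batches of tokenised "docs"; in the real CLI these come
# from corpus.train_docs() via minibatch(), and len(doc) is the number
# of tokens in a Doc.
batches = [
    [["This", "is", "a", "doc"], ["Another", "doc"]],
    [["One", "more", "training", "example", "here"]],
]
n_train_words = sum(len(doc) for batch in batches for doc in batch)

with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
    for docs in batches:
        # Advance by token count, not doc count, so the bar's units
        # match the word-level total computed above.
        pbar.update(sum(len(doc) for doc in docs))
```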
@@ -105,6 +106,9 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
                 corpus.dev_docs(
                     nlp_loaded,
                     gold_preproc=False))
+            acc_loc = (output_path / ('model%d' % i) / 'accuracy.json')
+            with acc_loc.open('w') as file_:
+                file_.write(json_dumps(scorer.scores))
             util.set_env_log(True)
             print_progress(i, losses, scorer.scores)
     finally:
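The added lines dump the per-iteration evaluation scores next to each saved model, which is what the new json_dumps import in the first hunk is for. A sketch of the resulting layout; the scores dict is fabricated, and the standard-library json.dumps stands in for spacy.compat.json_dumps:

```python
from pathlib import Path
import json

# Fabricated scorer.scores; the real dict comes from spaCy's Scorer
# after evaluating on the dev set.
scores = {"uas": 91.2, "las": 89.5, "tags_acc": 96.8, "token_acc": 99.1}

output_path = Path("model-output")
i = 0  # iteration index from the training loop

# Mirrors the committed line: (output_path / ('model%d' % i) / 'accuracy.json')
acc_loc = output_path / ("model%d" % i) / "accuracy.json"
# In the CLI the model directory already exists because the model is
# saved there first; it is created here so the sketch runs standalone.
acc_loc.parent.mkdir(parents=True, exist_ok=True)
with acc_loc.open("w") as file_:
    file_.write(json.dumps(scores))
```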