Update the train script, fixing GPU memory leak
parent 836fe1d880
commit 3376d4d6e8
@@ -17,7 +17,7 @@ from .. import displacy
 
 
 def train(language, output_dir, train_data, dev_data, n_iter, n_sents,
-          use_gpu, tagger, parser, ner, parser_L1):
+          use_gpu, no_tagger, no_parser, no_entities, parser_L1):
    output_path = util.ensure_path(output_dir)
    train_path = util.ensure_path(train_data)
    dev_path = util.ensure_path(dev_data)
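Note the inverted flag names: no_tagger, no_parser and no_entities default to falsy values, so a call that leaves them unset trains every component, and each flag only switches its component off (see the pipeline filtering in the final hunk below).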
@@ -44,9 +44,11 @@ def train(language, output_dir, train_data, dev_data, n_iter, n_sents,
         'lang': language,
         'features': lang.Defaults.tagger_features}
     gold_train = list(read_gold_json(train_path, limit=n_sents))
-    gold_dev = list(read_gold_json(dev_path, limit=n_sents)) if dev_path else None
+    gold_dev = list(read_gold_json(dev_path, limit=n_sents))
 
-    train_model(lang, gold_train, gold_dev, output_path, n_iter, use_gpu=use_gpu)
+    train_model(lang, gold_train, gold_dev, output_path, n_iter,
+                no_tagger=no_tagger, no_parser=no_parser, no_entities=no_entities,
+                use_gpu=use_gpu)
     if gold_dev:
         scorer = evaluate(lang, gold_dev, output_path)
         print_results(scorer)
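For orientation, a hypothetical call to the updated train() with the renamed keyword flags; the concrete argument values below are illustrative, not part of the commit:

    # Hypothetical usage (values illustrative; signature from the diff above):
    train('en', '/tmp/model', 'train.json', 'dev.json',
          n_iter=10, n_sents=0, use_gpu=False,
          no_tagger=False,   # keep the tagger
          no_parser=False,   # keep the parser
          no_entities=True,  # skip the named entity recognizer
          parser_L1=0.0)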
@@ -65,34 +67,43 @@ def train_config(config):
 def train_model(Language, train_data, dev_data, output_path, n_iter, **cfg):
     print("Itn.\tDep. Loss\tUAS\tNER F.\tTag %\tToken %")
 
-    nlp = Language(pipeline=['token_vectors', 'tags', 'dependencies'])
+    pipeline = ['token_vectors', 'tags', 'dependencies', 'entities']
+    if cfg.get('no_tagger') and 'tags' in pipeline:
+        pipeline.remove('tags')
+    if cfg.get('no_parser') and 'dependencies' in pipeline:
+        pipeline.remove('dependencies')
+    if cfg.get('no_entities') and 'entities' in pipeline:
+        pipeline.remove('entities')
+    print(pipeline)
+    nlp = Language(pipeline=pipeline)
     dropout = util.env_opt('dropout', 0.0)
     # TODO: Get spaCy using Thinc's trainer and optimizer
     with nlp.begin_training(train_data, **cfg) as (trainer, optimizer):
-        for itn, epoch in enumerate(trainer.epochs(n_iter, gold_preproc=True)):
+        for itn, epoch in enumerate(trainer.epochs(n_iter, gold_preproc=False)):
             losses = defaultdict(float)
-            to_render = []
             for i, (docs, golds) in enumerate(epoch):
-                state = nlp.update(docs, golds, drop=dropout, sgd=optimizer)
-                losses['dep_loss'] += state.get('parser_loss', 0.0)
-                losses['tag_loss'] += state.get('tag_loss', 0.0)
-                to_render.insert(0, nlp(docs[-1].text))
-                to_render[0].user_data['title'] = "Batch %d" % i
-                with Path('/tmp/entities.html').open('w') as file_:
-                    html = displacy.render(to_render[:5], style='ent', page=True)
-                    file_.write(html)
-                with Path('/tmp/parses.html').open('w') as file_:
-                    html = displacy.render(to_render[:5], style='dep', page=True)
-                    file_.write(html)
+                nlp.update(docs, golds, drop=dropout, sgd=optimizer)
+                for doc in docs:
+                    doc.tensor = None
+                    doc._py_tokens = []
             if dev_data:
                 with nlp.use_params(optimizer.averages):
-                    dev_scores = trainer.evaluate(dev_data).scores
+                    dev_scores = trainer.evaluate(dev_data, gold_preproc=False).scores
             else:
                 dev_scores = defaultdict(float)
             print_progress(itn, losses, dev_scores)
     with (output_path / 'model.bin').open('wb') as file_:
         dill.dump(nlp, file_, -1)
-    #nlp.to_disk(output_path, tokenizer=False)
+
+
+def _render_parses(i, to_render):
+    to_render[0].user_data['title'] = "Batch %d" % i
+    with Path('/tmp/entities.html').open('w') as file_:
+        html = displacy.render(to_render[:5], style='ent', page=True)
+        file_.write(html)
+    with Path('/tmp/parses.html').open('w') as file_:
+        html = displacy.render(to_render[:5], style='dep', page=True)
+        file_.write(html)
 
 
 def evaluate(Language, gold_tuples, path):
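The memory-leak fix itself is the small "for doc in docs" loop above: every Doc the training loop keeps alive also keeps its tensor of activations resident (on the GPU when use_gpu is set), and the old code additionally accumulated one rendered doc per batch in to_render for the length of an epoch. The new code drops each doc's tensor as soon as nlp.update() has consumed it, and moves the displaCy HTML dumping out of the inner loop into _render_parses. A minimal self-contained sketch of the same pattern, with an illustrative Doc stand-in rather than spaCy's class:

    import numpy

    class Doc(object):
        """Illustrative stand-in for spacy.tokens.Doc."""
        def __init__(self, n_words, width=128):
            # On the GPU this would be a cupy array; numpy keeps the
            # sketch runnable anywhere.
            self.tensor = numpy.zeros((n_words, width), dtype='float32')

    def update(docs):
        # Placeholder for nlp.update(): once the forward/backward pass
        # has run, the per-document activations are no longer needed.
        return sum(float(doc.tensor.sum()) for doc in docs)

    batches = [[Doc(25), Doc(40)] for _ in range(1000)]
    for docs in batches:
        update(docs)
        for doc in docs:
            # The commit's fix: drop the activations once the update is
            # done, so any Doc that stays reachable no longer pins a
            # large tensor in memory for the rest of the epoch.
            doc.tensor = None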