mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Refactor training, to fix memory leak
This commit is contained in:
parent
4803b3b69e
commit
4c9202249d
|
@ -129,9 +129,31 @@ class CLI(object):
|
||||||
print("\n Command %r does not exist."
|
print("\n Command %r does not exist."
|
||||||
"\n Use the --help flag for a list of available commands.\n" % name)
|
"\n Use the --help flag for a list of available commands.\n" % name)
|
||||||
|
|
||||||
|
@plac.annotations(
|
||||||
|
lang=("model language", "positional", None, str),
|
||||||
|
output_dir=("output directory to store model in", "positional", None, str),
|
||||||
|
train_data=("location of JSON-formatted training data", "positional", None, str),
|
||||||
|
dev_data=("location of JSON-formatted development data (optional)", "positional", None, str),
|
||||||
|
n_iter=("number of iterations", "option", "n", int),
|
||||||
|
nsents=("number of sentences", "option", None, int),
|
||||||
|
use_gpu=("Use GPU", "flag", "g", bool),
|
||||||
|
no_tagger=("Don't train tagger", "flag", "T", bool),
|
||||||
|
no_parser=("Don't train parser", "flag", "P", bool),
|
||||||
|
no_entities=("Don't train NER", "flag", "N", bool)
|
||||||
|
)
|
||||||
|
def train(self, lang, output_dir, train_data, dev_data=None, n_iter=15,
|
||||||
|
nsents=0, use_gpu=False,
|
||||||
|
no_tagger=False, no_parser=False, no_entities=False):
|
||||||
|
"""
|
||||||
|
Train a model. Expects data in spaCy's JSON format.
|
||||||
|
"""
|
||||||
|
nsents = nsents or None
|
||||||
|
cli_train(lang, output_dir, train_data, dev_data, n_iter, nsents,
|
||||||
|
use_gpu, no_tagger, no_parser, no_entities)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
import plac
|
import plac
|
||||||
import sys
|
import sys
|
||||||
sys.argv[0] = 'spacy'
|
if sys.argv[1] == 'train':
|
||||||
plac.Interpreter.call(CLI)
|
plac.call(train)
|
||||||
|
|
|
@ -6,18 +6,19 @@ from collections import defaultdict
|
||||||
import cytoolz
|
import cytoolz
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import dill
|
import dill
|
||||||
|
import tqdm
|
||||||
|
|
||||||
from ..tokens.doc import Doc
|
from ..tokens.doc import Doc
|
||||||
from ..scorer import Scorer
|
from ..scorer import Scorer
|
||||||
from ..gold import GoldParse, merge_sents
|
from ..gold import GoldParse, merge_sents
|
||||||
from ..gold import read_json_file as read_gold_json
|
from ..gold import GoldCorpus
|
||||||
from ..util import prints
|
from ..util import prints
|
||||||
from .. import util
|
from .. import util
|
||||||
from .. import displacy
|
from .. import displacy
|
||||||
|
|
||||||
|
|
||||||
def train(language, output_dir, train_data, dev_data, n_iter, n_sents,
|
def train(lang_id, output_dir, train_data, dev_data, n_iter, n_sents,
|
||||||
use_gpu, no_tagger, no_parser, no_entities, parser_L1):
|
use_gpu, no_tagger, no_parser, no_entities):
|
||||||
output_path = util.ensure_path(output_dir)
|
output_path = util.ensure_path(output_dir)
|
||||||
train_path = util.ensure_path(train_data)
|
train_path = util.ensure_path(train_data)
|
||||||
dev_path = util.ensure_path(dev_data)
|
dev_path = util.ensure_path(dev_data)
|
||||||
|
@ -28,70 +29,32 @@ def train(language, output_dir, train_data, dev_data, n_iter, n_sents,
|
||||||
if dev_path and not dev_path.exists():
|
if dev_path and not dev_path.exists():
|
||||||
prints(dev_path, title="Development data not found", exits=True)
|
prints(dev_path, title="Development data not found", exits=True)
|
||||||
|
|
||||||
lang = util.get_lang_class(language)
|
lang_class = util.get_lang_class(lang_id)
|
||||||
parser_cfg = {
|
|
||||||
'pseudoprojective': True,
|
|
||||||
'L1': parser_L1,
|
|
||||||
'n_iter': n_iter,
|
|
||||||
'lang': language,
|
|
||||||
'features': lang.Defaults.parser_features}
|
|
||||||
entity_cfg = {
|
|
||||||
'n_iter': n_iter,
|
|
||||||
'lang': language,
|
|
||||||
'features': lang.Defaults.entity_features}
|
|
||||||
tagger_cfg = {
|
|
||||||
'n_iter': n_iter,
|
|
||||||
'lang': language,
|
|
||||||
'features': lang.Defaults.tagger_features}
|
|
||||||
gold_train = list(read_gold_json(train_path, limit=n_sents))
|
|
||||||
gold_dev = list(read_gold_json(dev_path, limit=n_sents))
|
|
||||||
|
|
||||||
train_model(lang, gold_train, gold_dev, output_path, n_iter,
|
|
||||||
no_tagger=no_tagger, no_parser=no_parser, no_entities=no_entities,
|
|
||||||
use_gpu=use_gpu)
|
|
||||||
if gold_dev:
|
|
||||||
scorer = evaluate(lang, gold_dev, output_path)
|
|
||||||
print_results(scorer)
|
|
||||||
|
|
||||||
|
|
||||||
def train_config(config):
|
|
||||||
config_path = util.ensure_path(config)
|
|
||||||
if not config_path.is_file():
|
|
||||||
prints(config_path, title="Config file not found", exits=True)
|
|
||||||
config = json.load(config_path)
|
|
||||||
for setting in []:
|
|
||||||
if setting not in config.keys():
|
|
||||||
prints("%s not found in config file." % setting, title="Missing setting")
|
|
||||||
|
|
||||||
|
|
||||||
def train_model(Language, train_data, dev_data, output_path, n_iter, **cfg):
|
|
||||||
print("Itn.\tDep. Loss\tUAS\tNER F.\tTag %\tToken %")
|
|
||||||
|
|
||||||
pipeline = ['token_vectors', 'tags', 'dependencies', 'entities']
|
pipeline = ['token_vectors', 'tags', 'dependencies', 'entities']
|
||||||
if cfg.get('no_tagger') and 'tags' in pipeline:
|
if no_tagger and 'tags' in pipeline: pipeline.remove('tags')
|
||||||
pipeline.remove('tags')
|
if no_parser and 'dependencies' in pipeline: pipeline.remove('dependencies')
|
||||||
if cfg.get('no_parser') and 'dependencies' in pipeline:
|
if no_entities and 'entities' in pipeline: pipeline.remove('entities')
|
||||||
pipeline.remove('dependencies')
|
|
||||||
if cfg.get('no_entities') and 'entities' in pipeline:
|
nlp = lang_class(pipeline=pipeline)
|
||||||
pipeline.remove('entities')
|
corpus = GoldCorpus(train_path, dev_path)
|
||||||
print(pipeline)
|
|
||||||
nlp = Language(pipeline=pipeline)
|
|
||||||
dropout = util.env_opt('dropout', 0.0)
|
dropout = util.env_opt('dropout', 0.0)
|
||||||
# TODO: Get spaCy using Thinc's trainer and optimizer
|
|
||||||
with nlp.begin_training(train_data, **cfg) as (trainer, optimizer):
|
optimizer = nlp.begin_training(lambda: corpus.train_tuples, use_gpu=use_gpu)
|
||||||
for itn, epoch in enumerate(trainer.epochs(n_iter, gold_preproc=False)):
|
n_train_docs = corpus.count_train()
|
||||||
losses = defaultdict(float)
|
print("Itn.\tDep. Loss\tUAS\tNER F.\tTag %\tToken %")
|
||||||
for i, (docs, golds) in enumerate(epoch):
|
for i in range(n_iter):
|
||||||
|
with tqdm.tqdm(total=n_train_docs) as pbar:
|
||||||
|
train_docs = corpus.train_docs(nlp, shuffle=i)
|
||||||
|
for batch in cytoolz.partition_all(20, train_docs):
|
||||||
|
docs, golds = zip(*batch)
|
||||||
|
docs = list(docs)
|
||||||
|
golds = list(golds)
|
||||||
nlp.update(docs, golds, drop=dropout, sgd=optimizer)
|
nlp.update(docs, golds, drop=dropout, sgd=optimizer)
|
||||||
for doc in docs:
|
pbar.update(len(docs))
|
||||||
doc.tensor = None
|
scorer = nlp.evaluate(corpus.dev_docs(nlp))
|
||||||
doc._py_tokens = []
|
print_progress(i, {}, scorer.scores)
|
||||||
if dev_data:
|
|
||||||
with nlp.use_params(optimizer.averages):
|
|
||||||
dev_scores = trainer.evaluate(dev_data, gold_preproc=False).scores
|
|
||||||
else:
|
|
||||||
dev_scores = defaultdict(float)
|
|
||||||
print_progress(itn, losses, dev_scores)
|
|
||||||
with (output_path / 'model.bin').open('wb') as file_:
|
with (output_path / 'model.bin').open('wb') as file_:
|
||||||
dill.dump(nlp, file_, -1)
|
dill.dump(nlp, file_, -1)
|
||||||
|
|
||||||
|
|
|
@ -6,12 +6,12 @@ import dill
|
||||||
import numpy
|
import numpy
|
||||||
from thinc.neural import Model
|
from thinc.neural import Model
|
||||||
from thinc.neural.ops import NumpyOps, CupyOps
|
from thinc.neural.ops import NumpyOps, CupyOps
|
||||||
|
from thinc.neural.optimizers import Adam
|
||||||
|
|
||||||
from .tokenizer import Tokenizer
|
from .tokenizer import Tokenizer
|
||||||
from .vocab import Vocab
|
from .vocab import Vocab
|
||||||
from .tagger import Tagger
|
from .tagger import Tagger
|
||||||
from .lemmatizer import Lemmatizer
|
from .lemmatizer import Lemmatizer
|
||||||
from .train import Trainer
|
|
||||||
from .syntax.parser import get_templates
|
from .syntax.parser import get_templates
|
||||||
from .syntax.nonproj import PseudoProjectivity
|
from .syntax.nonproj import PseudoProjectivity
|
||||||
from .pipeline import NeuralDependencyParser, EntityRecognizer
|
from .pipeline import NeuralDependencyParser, EntityRecognizer
|
||||||
|
@ -23,6 +23,7 @@ from .lang.tokenizer_exceptions import TOKEN_MATCH
|
||||||
from .lang.tag_map import TAG_MAP
|
from .lang.tag_map import TAG_MAP
|
||||||
from .lang.lex_attrs import LEX_ATTRS
|
from .lang.lex_attrs import LEX_ATTRS
|
||||||
from . import util
|
from . import util
|
||||||
|
from .scorer import Scorer
|
||||||
|
|
||||||
|
|
||||||
class BaseDefaults(object):
|
class BaseDefaults(object):
|
||||||
|
@ -181,8 +182,8 @@ class Language(object):
|
||||||
for proc in self.pipeline[1:]:
|
for proc in self.pipeline[1:]:
|
||||||
grads = {}
|
grads = {}
|
||||||
tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop)
|
tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop)
|
||||||
d_tokvecses = proc.update((docs, tokvecses), golds, sgd=get_grads, drop=drop)
|
d_tokvecses = proc.update((docs, tokvecses), golds, sgd=sgd, drop=drop)
|
||||||
bp_tokvecses(d_tokvecses, sgd=get_grads)
|
bp_tokvecses(d_tokvecses, sgd=sgd)
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
for key, (W, dW) in grads.items():
|
for key, (W, dW) in grads.items():
|
||||||
# TODO: Unhack this when thinc improves
|
# TODO: Unhack this when thinc improves
|
||||||
|
@ -191,16 +192,24 @@ class Language(object):
|
||||||
else:
|
else:
|
||||||
sgd.ops = CupyOps()
|
sgd.ops = CupyOps()
|
||||||
sgd(W, dW, key=key)
|
sgd(W, dW, key=key)
|
||||||
|
for key in list(grads.keys()):
|
||||||
|
grads.pop(key)
|
||||||
|
for doc in docs:
|
||||||
|
doc.tensor = None
|
||||||
|
|
||||||
@contextmanager
|
def preprocess_gold(self, docs_golds):
|
||||||
def begin_training(self, gold_tuples, **cfg):
|
for proc in self.pipeline:
|
||||||
|
if hasattr(proc, 'preprocess_gold'):
|
||||||
|
docs_golds = proc.preprocess_gold(docs_golds)
|
||||||
|
for doc, gold in docs_golds:
|
||||||
|
yield doc, gold
|
||||||
|
|
||||||
|
def begin_training(self, get_gold_tuples, **cfg):
|
||||||
# Populate vocab
|
# Populate vocab
|
||||||
for _, annots_brackets in gold_tuples:
|
for _, annots_brackets in get_gold_tuples():
|
||||||
for annots, _ in annots_brackets:
|
for annots, _ in annots_brackets:
|
||||||
for word in annots[1]:
|
for word in annots[1]:
|
||||||
_ = self.vocab[word]
|
_ = self.vocab[word]
|
||||||
# Handle crossing dependencies
|
|
||||||
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
|
|
||||||
contexts = []
|
contexts = []
|
||||||
if cfg.get('use_gpu'):
|
if cfg.get('use_gpu'):
|
||||||
Model.ops = CupyOps()
|
Model.ops = CupyOps()
|
||||||
|
@ -208,11 +217,18 @@ class Language(object):
|
||||||
print("Use GPU")
|
print("Use GPU")
|
||||||
for proc in self.pipeline:
|
for proc in self.pipeline:
|
||||||
if hasattr(proc, 'begin_training'):
|
if hasattr(proc, 'begin_training'):
|
||||||
context = proc.begin_training(gold_tuples,
|
context = proc.begin_training(get_gold_tuples(),
|
||||||
pipeline=self.pipeline)
|
pipeline=self.pipeline)
|
||||||
contexts.append(context)
|
contexts.append(context)
|
||||||
trainer = Trainer(self, gold_tuples, **cfg)
|
optimizer = Adam(Model.ops, 0.001)
|
||||||
yield trainer, trainer.optimizer
|
return optimizer
|
||||||
|
|
||||||
|
def evaluate(self, docs_golds):
|
||||||
|
docs, golds = zip(*docs_golds)
|
||||||
|
scorer = Scorer()
|
||||||
|
for doc, gold in zip(self.pipe(docs), golds):
|
||||||
|
scorer.score(doc, gold)
|
||||||
|
return scorer
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def use_params(self, params, **cfg):
|
def use_params(self, params, **cfg):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user