mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-15 20:16:23 +03:00
* Wire hyperparameters to script interface
This commit is contained in:
parent
ebe630cc8d
commit
da793073d0
|
@ -84,7 +84,8 @@ def _merge_sents(sents):
|
|||
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
||||
seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
|
||||
verbose=False,
|
||||
eta=0.01, mu=0.9, n_hidden=100, word_vec_len=10, pos_vec_len=10):
|
||||
eta=0.01, mu=0.9, n_hidden=100,
|
||||
nv_word=10, nv_tag=10, nv_label=10):
|
||||
dep_model_dir = path.join(model_dir, 'deps')
|
||||
pos_model_dir = path.join(model_dir, 'pos')
|
||||
ner_model_dir = path.join(model_dir, 'ner')
|
||||
|
@ -99,8 +100,15 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
|||
os.mkdir(ner_model_dir)
|
||||
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir)
|
||||
|
||||
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
||||
labels=Language.ParserTransitionSystem.get_labels(gold_tuples))
|
||||
Config.write(dep_model_dir, 'config',
|
||||
seed=seed,
|
||||
features=feat_set,
|
||||
labels=Language.ParserTransitionSystem.get_labels(gold_tuples),
|
||||
vector_lengths=(nv_word, nv_tag, nv_label),
|
||||
hidden_nodes=n_hidden,
|
||||
eta=eta,
|
||||
mu=mu
|
||||
)
|
||||
Config.write(ner_model_dir, 'config', features='ner', seed=seed,
|
||||
labels=Language.EntityTransitionSystem.get_labels(gold_tuples),
|
||||
beam_width=0)
|
||||
|
@ -110,16 +118,17 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
|||
|
||||
nlp = Language(data_dir=model_dir)
|
||||
|
||||
def make_model(n_classes, input_spec, model_dir):
|
||||
print input_spec
|
||||
n_in = sum(n_cols * len(fields) for (n_cols, fields) in input_spec)
|
||||
def make_model(n_classes, (words, tags, labels), model_dir):
|
||||
n_in = (nv_word * len(words)) + \
|
||||
(nv_tag * len(tags)) + \
|
||||
(nv_label * len(labels))
|
||||
print 'Compiling'
|
||||
debug, train_func, predict_func = compile_theano_model(n_classes, n_hidden,
|
||||
n_in, 0.0, 0.0)
|
||||
print 'Done'
|
||||
return TheanoModel(
|
||||
n_classes,
|
||||
input_spec,
|
||||
((nv_word, words), (nv_tag, tags), (nv_label, labels)),
|
||||
train_func,
|
||||
predict_func,
|
||||
model_loc=model_dir,
|
||||
|
@ -226,14 +235,23 @@ def write_parses(Language, dev_loc, model_dir, out_loc, beam_width=None):
|
|||
n_sents=("Number of training sentences", "option", "n", int),
|
||||
n_iter=("Number of training iterations", "option", "i", int),
|
||||
verbose=("Verbose error reporting", "flag", "v", bool),
|
||||
debug=("Debug mode", "flag", "d", bool),
|
||||
|
||||
nv_word=("Word vector length", "option", "W", int),
|
||||
nv_tag=("Tag vector length", "option", "T", int),
|
||||
nv_label=("Label vector length", "option", "L", int),
|
||||
nv_hidden=("Hidden nodes length", "option", "H", int),
|
||||
eta=("Learning rate", "option", "E", float),
|
||||
mu=("Momentum", "option", "M", float),
|
||||
)
|
||||
def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
|
||||
debug=False, corruption_level=0.0, gold_preproc=False, beam_width=1,
|
||||
corruption_level=0.0, gold_preproc=False,
|
||||
nv_word=10, nv_tag=10, nv_label=10, nv_hidden=10,
|
||||
eta=0.1, mu=0.9,
|
||||
eval_only=False):
|
||||
gold_train = list(read_json_file(train_loc))
|
||||
nlp = train(English, gold_train, model_dir,
|
||||
feat_set='embed',
|
||||
nv_word=nv_word, nv_tag=nv_tag, nv_label=nv_label,
|
||||
gold_preproc=gold_preproc, n_sents=n_sents,
|
||||
corruption_level=corruption_level, n_iter=n_iter,
|
||||
verbose=verbose)
|
||||
|
|
Loading…
Reference in New Issue
Block a user