Connect parser L1 option to train CLI

This commit is contained in:
Matthew Honnibal 2017-03-26 07:24:07 -05:00
parent ed2b106f4d
commit 6b7f7a2060
2 changed files with 7 additions and 2 deletions

View File

@@ -82,16 +82,19 @@ class CLI(object):
train_data=("training data", "positional", None, str), train_data=("training data", "positional", None, str),
dev_data=("development data", "positional", None, str), dev_data=("development data", "positional", None, str),
n_iter=("number of iterations", "option", "n", int), n_iter=("number of iterations", "option", "n", int),
parser_L1=("L1 regularization penalty for parser", "option", "L", float),
no_tagger=("Don't train tagger", "flag", "T", bool), no_tagger=("Don't train tagger", "flag", "T", bool),
no_parser=("Don't train parser", "flag", "P", bool), no_parser=("Don't train parser", "flag", "P", bool),
no_ner=("Don't train NER", "flag", "N", bool) no_ner=("Don't train NER", "flag", "N", bool)
) )
def train(self, lang, output_dir, train_data, dev_data, n_iter=15, def train(self, lang, output_dir, train_data, dev_data, n_iter=15,
parser_L1=0.0,
no_tagger=False, no_parser=False, no_ner=False): no_tagger=False, no_parser=False, no_ner=False):
"""Train a model.""" """Train a model."""
cli_train(lang, output_dir, train_data, dev_data, n_iter, cli_train(lang, output_dir, train_data, dev_data, n_iter,
not no_tagger, not no_parser, not no_ner) not no_tagger, not no_parser, not no_ner,
parser_L1)
@plac.annotations( @plac.annotations(

View File

@@ -13,7 +13,8 @@ from ..gold import read_json_file as read_gold_json
from .. import util from .. import util
def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ner): def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ner,
parser_L1):
output_path = Path(output_dir) output_path = Path(output_dir)
train_path = Path(train_data) train_path = Path(train_data)
dev_path = Path(dev_data) dev_path = Path(dev_data)
@@ -22,6 +23,7 @@ def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ne
lang = util.get_lang_class(language) lang = util.get_lang_class(language)
parser_cfg = { parser_cfg = {
'pseudoprojective': True, 'pseudoprojective': True,
'L1': parser_L1,
'n_iter': n_iter, 'n_iter': n_iter,
'lang': language, 'lang': language,
'features': lang.Defaults.parser_features} 'features': lang.Defaults.parser_features}