Update conll_train script
parent 136a7a2322
commit 0c7520dbb7
@@ -119,7 +119,8 @@ def score_sents(nlp, gold_tuples):
 
 
 def train(Language, gold_tuples, model_dir, dev_loc, n_iter=15, feat_set=u'basic',
-          learn_rate=0.001, noise=0.01, update_step='sgd_cm',
+          learn_rate=0.001, noise=0.01, update_step='sgd_cm', regularization=0.0,
+          width=128, depth=3,
           batch_norm=False, seed=0, gold_preproc=False, force_gold=False):
     dep_model_dir = path.join(model_dir, 'deps')
     pos_model_dir = path.join(model_dir, 'pos')
@@ -132,11 +133,11 @@ def train(Language, gold_tuples, model_dir, dev_loc, n_iter=15, feat_set=u'basic
 
     if feat_set != 'neural':
         Config.write(dep_model_dir, 'config', feat_set=feat_set, seed=seed,
-                     labels=ArcEager.get_labels(gold_tuples))
+                     labels=ArcEager.get_labels(gold_tuples),
+                     eta=learn_rate, rho=regularization)
 
     else:
-        hidden_layers = [128] * 3
-        rho = 1e-4
+        hidden_layers = [width] * depth
         Config.write(dep_model_dir, 'config',
                      model='neural',
                      seed=seed,
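The neural branch no longer hard-codes its architecture: the layer sizes come from the new width and depth arguments, and the fixed rho = 1e-4 penalty gives way to the regularization option. A minimal sketch of what this produces; note that the patch defaults (width=128, depth=3) reproduce the previously hard-coded layout:

    # Sketch only: how width/depth translate into the hidden layer sizes
    # written to the parser config. Values are the patch defaults.
    width, depth = 128, 3
    hidden_layers = [width] * depth
    print(hidden_layers)  # [128, 128, 128], same as the old hard-coded [128] * 3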
@@ -148,18 +149,18 @@ def train(Language, gold_tuples, model_dir, dev_loc, n_iter=15, feat_set=u'basic
                      eta=learn_rate,
                      mu=0.9,
                      noise=noise,
-                     ensemble_size=1,
-                     rho=rho)
+                     rho=regularization)
 
     nlp = Language(data_dir=model_dir, tagger=False, parser=False, entity=False)
     nlp.tagger = Tagger.blank(nlp.vocab, Tagger.default_templates())
-    nlp.parser = BeamParser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager)
+    #nlp.parser = BeamParser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager)
+    nlp.parser = Parser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager)
     for word in nlp.vocab:
         word.norm = word.orth
 
     print(nlp.parser.model.widths)
 
-    print("Itn.\tP.Loss\tPruned\tTrain\tDev\tSize")
+    print("Itn.\tP.Loss\tTrain\tDev\tnr_weight")
     last_score = 0.0
     nr_trimmed = 0
     eg_seen = 0
@@ -197,9 +198,9 @@ def _train_epoch(nlp, gold_tuples, eg_seen, itn, dev_loc, micro_eval):
     else:
         dev_uas = 0.0
     train_uas = score_sents(nlp, micro_eval).uas
-    size = nlp.parser.model.mem.size
+    size = nlp.parser.model.nr_weight
     nr_upd = nlp.parser.model.time
-    print('%d,%d:\t%d\t%.3f\t%.3f\t%.3f\t%d' % (itn, nr_upd, int(loss), nr_trimmed,
+    print('%d,%d:\t%d\t%.3f\t%.3f\t%d' % (itn, nr_upd, int(loss),
           train_uas, dev_uas, size))
     loss = 0
     return eg_seen
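The per-iteration report drops the Pruned column (and its nr_trimmed argument), and the final column now shows nlp.parser.model.nr_weight instead of the raw memory size. A self-contained sketch of one progress row under the new format; the numbers are made up, but both format strings are taken verbatim from the diff:

    # Illustrative values only.
    itn, nr_upd, loss, train_uas, dev_uas, size = 1, 5000, 2034.7, 0.913, 0.905, 1200000
    print("Itn.\tP.Loss\tTrain\tDev\tnr_weight")
    print('%d,%d:\t%d\t%.3f\t%.3f\t%d' % (itn, nr_upd, int(loss), train_uas, dev_uas, size))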
@@ -213,20 +214,26 @@ def _train_epoch(nlp, gold_tuples, eg_seen, itn, dev_loc, micro_eval):
     batch_norm=("Use batch normalization and residual connections", "flag", "b"),
     update_step=("Update step", "option", "u", str),
     learn_rate=("Learn rate", "option", "e", float),
-    gradient_noise=("Gradient noise", "option", "w", float),
-    neural=("Use neural network?", "flag", "N")
+    regularization=("Regularization penalty", "option", "r", float),
+    gradient_noise=("Gradient noise", "option", "W", float),
+    neural=("Use neural network?", "flag", "N"),
+    width=("Width of hidden layers", "option", "w", int),
+    depth=("Number of hidden layers", "option", "d", int),
 )
 def main(train_loc, dev_loc, model_dir, n_iter=15, neural=False, batch_norm=False,
-         learn_rate=0.001, gradient_noise=0.0, update_step='sgd_cm'):
+         width=128, depth=3, learn_rate=0.001, gradient_noise=0.0, regularization=0.0,
+         update_step='sgd_cm'):
     with io.open(train_loc, 'r', encoding='utf8') as file_:
         train_sents = list(read_conll(file_))
     # Preprocess training data here before ArcEager.get_labels() is called
     train_sents = PseudoProjectivity.preprocess_training_data(train_sents)
 
     nlp = train(English, train_sents, model_dir, dev_loc, n_iter=n_iter,
+                width=width, depth=depth,
                 feat_set='neural' if neural else 'basic',
                 batch_norm=batch_norm,
                 learn_rate=learn_rate,
+                regularization=regularization,
                 update_step=update_step,
                 noise=gradient_noise)
 
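The new command-line options are declared as plac annotations of the form name=("help", "kind", "abbreviation", type), presumably applied via plac.annotations given the closing parenthesis before def main() and the plac.call(main) at the bottom of the script. Note that gradient noise moves from -w to -W so that -w can take the hidden-layer width. A minimal, self-contained sketch of how plac maps such annotations to flags; the script and invocation below are hypothetical and not part of the patch:

    import plac

    @plac.annotations(
        neural=("Use neural network?", "flag", "N"),
        width=("Width of hidden layers", "option", "w", int),
        depth=("Number of hidden layers", "option", "d", int),
        regularization=("Regularization penalty", "option", "r", float),
        gradient_noise=("Gradient noise", "option", "W", float),
    )
    def demo(neural=False, width=128, depth=3, regularization=0.0, gradient_noise=0.0):
        # e.g. `python demo.py -N -w 256 -d 4 -r 1e-4` sets these values
        print(neural, width, depth, regularization, gradient_noise)

    if __name__ == '__main__':
        plac.call(demo)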
@@ -237,6 +244,5 @@ def main(train_loc, dev_loc, model_dir, n_iter=15, neural=False, batch_norm=Fals
     print('LAS', scorer.las)
 
 
-
 if __name__ == '__main__':
     plac.call(main)