mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-12 13:14:18 +03:00
Fix bin/parser/train
This commit is contained in:
parent
cf2131d649
commit
7c7a05a466
|
@ -52,18 +52,6 @@ def add_noise(orig, noise_level):
|
|||
return ''.join(_corrupt(c, noise_level) for c in orig)
|
||||
|
||||
|
||||
def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False):
|
||||
if raw_text is None:
|
||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||
else:
|
||||
tokens = nlp.tokenizer(raw_text)
|
||||
nlp.tagger(tokens)
|
||||
nlp.entity(tokens)
|
||||
nlp.parser(tokens)
|
||||
gold = GoldParse(tokens, annot_tuples)
|
||||
scorer.score(tokens, gold, verbose=verbose)
|
||||
|
||||
|
||||
def _merge_sents(sents):
|
||||
m_deps = [[], [], [], [], [], []]
|
||||
m_brackets = []
|
||||
|
@ -80,7 +68,7 @@ def _merge_sents(sents):
|
|||
return [(m_deps, m_brackets)]
|
||||
|
||||
|
||||
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
||||
def train(Language, gold_tuples, model_dir, dev_loc, n_iter=15, feat_set=u'basic',
|
||||
seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
|
||||
beam_width=1, verbose=False,
|
||||
use_orig_arc_eager=False, pseudoprojective=False):
|
||||
|
@ -101,8 +89,9 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
|||
# preprocess training data here before ArcEager.get_labels() is called
|
||||
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
|
||||
|
||||
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
||||
Config.write(dep_model_dir, 'config', feat_set=feat_set, seed=seed,
|
||||
labels=ArcEager.get_labels(gold_tuples),
|
||||
rho=0.0, eta=1.0, mu=0.9, noise=0.0,
|
||||
beam_width=beam_width,projectivize=pseudoprojective)
|
||||
#feat_set, slots = get_templates('neural')
|
||||
#vector_widths = [10, 10, 10]
|
||||
|
@ -121,56 +110,108 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
|||
# rho=rho)
|
||||
|
||||
|
||||
Config.write(ner_model_dir, 'config', features='ner', seed=seed,
|
||||
Config.write(ner_model_dir, 'config', feat_set='ner', seed=seed,
|
||||
labels=BiluoPushDown.get_labels(gold_tuples),
|
||||
beam_width=0)
|
||||
beam_width=0, rho=0.0, eta=1.0, mu=0.9, noise=0.0)
|
||||
|
||||
if n_sents > 0:
|
||||
gold_tuples = gold_tuples[:n_sents]
|
||||
|
||||
micro_eval = gold_tuples[:50]
|
||||
nlp = Language(data_dir=model_dir, tagger=False, parser=False, entity=False)
|
||||
nlp.tagger = Tagger.blank(nlp.vocab, Tagger.default_templates())
|
||||
nlp.parser = BeamParser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager)
|
||||
nlp.entity = BeamParser.from_dir(ner_model_dir, nlp.vocab.strings, BiluoPushDown)
|
||||
nlp.parser = Parser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager)
|
||||
nlp.entity = Parser.from_dir(ner_model_dir, nlp.vocab.strings, BiluoPushDown)
|
||||
print(nlp.parser.model.widths)
|
||||
for raw_text, sents in gold_tuples:
|
||||
for annot_tuples, ctnt in sents:
|
||||
for word in annot_tuples[1]:
|
||||
_ = nlp.vocab[word]
|
||||
eg_seen = 0
|
||||
print("Itn.\tP.Loss\tUAS\tNER F.\tTag %\tToken %")
|
||||
for itn in range(n_iter):
|
||||
scorer = Scorer()
|
||||
loss = 0
|
||||
for raw_text, sents in gold_tuples:
|
||||
if gold_preproc:
|
||||
raw_text = None
|
||||
else:
|
||||
sents = _merge_sents(sents)
|
||||
for annot_tuples, ctnt in sents:
|
||||
if len(annot_tuples[1]) == 1:
|
||||
continue
|
||||
score_model(scorer, nlp, raw_text, annot_tuples,
|
||||
verbose=verbose if itn >= 2 else False)
|
||||
if raw_text is None:
|
||||
words = add_noise(annot_tuples[1], corruption_level)
|
||||
tokens = nlp.tokenizer.tokens_from_list(words)
|
||||
else:
|
||||
raw_text = add_noise(raw_text, corruption_level)
|
||||
tokens = nlp.tokenizer(raw_text)
|
||||
nlp.tagger(tokens)
|
||||
gold = GoldParse(tokens, annot_tuples)
|
||||
if not gold.is_projective:
|
||||
raise Exception("Non-projective sentence in training: %s" % annot_tuples[1])
|
||||
loss += nlp.parser.train(tokens, gold)
|
||||
nlp.entity.train(tokens, gold)
|
||||
nlp.tagger.train(tokens, gold.tags)
|
||||
random.shuffle(gold_tuples)
|
||||
print('%d:\t%d\t%.3f\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.ents_f,
|
||||
scorer.tags_acc,
|
||||
scorer.token_acc))
|
||||
print('end training')
|
||||
try:
|
||||
eg_seen = _train_epoch(nlp, gold_tuples, eg_seen, itn,
|
||||
dev_loc, micro_eval,
|
||||
gold_preproc, corruption_level)
|
||||
except KeyboardInterrupt:
|
||||
print("Saving model...")
|
||||
break
|
||||
dev_uas = score_file(nlp, dev_loc).uas
|
||||
print("Dev before average", dev_uas)
|
||||
nlp.end_training(model_dir)
|
||||
print('done')
|
||||
print("Saved. Evaluating...")
|
||||
|
||||
|
||||
def _train_epoch(nlp, gold_tuples, eg_seen, itn, dev_loc, micro_eval,
|
||||
gold_preproc, corruption_level):
|
||||
random.shuffle(gold_tuples)
|
||||
loss = 0
|
||||
nr_trimmed = 0
|
||||
for raw_text, sents in gold_tuples:
|
||||
if gold_preproc:
|
||||
raw_text = None
|
||||
else:
|
||||
sents = _merge_sents(sents)
|
||||
for annot_tuples, ctnt in sents:
|
||||
if len(annot_tuples[1]) == 1:
|
||||
continue
|
||||
if raw_text is None:
|
||||
words = add_noise(annot_tuples[1], corruption_level)
|
||||
tokens = nlp.tokenizer.tokens_from_list(words)
|
||||
else:
|
||||
raw_text = add_noise(raw_text, corruption_level)
|
||||
tokens = nlp.tokenizer(raw_text)
|
||||
nlp.tagger(tokens)
|
||||
gold = GoldParse(tokens, annot_tuples)
|
||||
if not gold.is_projective:
|
||||
raise Exception("Non-projective sentence in training: %s" % annot_tuples[1])
|
||||
loss += nlp.parser.train(tokens, gold)
|
||||
nlp.entity.train(tokens, gold)
|
||||
nlp.tagger.train(tokens, gold.tags)
|
||||
|
||||
eg_seen += 1
|
||||
if eg_seen % 1000 == 0:
|
||||
scorer = score_sents(nlp, micro_eval)
|
||||
print('%d:\t%d\t%.3f\t%.3f\t%.3f\t%.3f\t%d\t%d' % (itn, loss, scorer.uas, scorer.ents_f,
|
||||
scorer.tags_acc,
|
||||
scorer.token_acc,
|
||||
nlp.parser.model.nr_active_feat,
|
||||
nlp.entity.model.nr_active_feat))
|
||||
loss = 0
|
||||
nlp.parser.model.learn_rate *= 0.99
|
||||
scorer = score_file(nlp, dev_loc)
|
||||
print('D:\t%d\t%.3f\t%.3f\t%.3f\t%.3f' % (loss, scorer.uas, scorer.ents_f,
|
||||
scorer.tags_acc, scorer.token_acc))
|
||||
return eg_seen
|
||||
|
||||
|
||||
def score_file(nlp, loc):
|
||||
gold_sents = read_json_file(loc, verbose=False)
|
||||
scorer = Scorer()
|
||||
for _, sents in gold_sents:
|
||||
for annot_tuples, _ in sents:
|
||||
score_model(scorer, nlp, None, annot_tuples)
|
||||
return scorer
|
||||
|
||||
|
||||
def score_sents(nlp, gold_tuples):
|
||||
scorer = Scorer()
|
||||
for _, sents in gold_tuples:
|
||||
for annot_tuples, _ in sents:
|
||||
score_model(scorer, nlp, None, annot_tuples)
|
||||
return scorer
|
||||
|
||||
|
||||
def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False):
|
||||
if raw_text is None:
|
||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||
else:
|
||||
tokens = nlp.tokenizer(raw_text)
|
||||
nlp.tagger(tokens)
|
||||
nlp.entity(tokens)
|
||||
nlp.parser(tokens)
|
||||
gold = GoldParse(tokens, annot_tuples)
|
||||
scorer.score(tokens, gold, verbose=verbose)
|
||||
|
||||
|
||||
def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False,
|
||||
|
@ -201,7 +242,7 @@ def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False
|
|||
|
||||
def write_parses(Language, dev_loc, model_dir, out_loc):
|
||||
nlp = Language(data_dir=model_dir)
|
||||
gold_tuples = read_json_file(dev_loc)
|
||||
gold_tuples = read_json_file(dev_loc, verbose=True)
|
||||
scorer = Scorer()
|
||||
out_file = io.open(out_loc, 'w', 'utf8')
|
||||
for raw_text, sents in gold_tuples:
|
||||
|
@ -245,16 +286,16 @@ def main(language, train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc=
|
|||
lang = spacy.util.get_lang_class(language)
|
||||
|
||||
if not eval_only:
|
||||
gold_train = list(read_json_file(train_loc))
|
||||
train(lang, gold_train, model_dir,
|
||||
feat_set='neural' if not debug else 'debug',
|
||||
gold_train = list(read_json_file(train_loc, verbose=True))
|
||||
train(lang, gold_train, model_dir, dev_loc,
|
||||
feat_set='basic', #'neural' if not debug else 'debug',
|
||||
gold_preproc=gold_preproc, n_sents=n_sents,
|
||||
corruption_level=corruption_level, n_iter=n_iter,
|
||||
verbose=verbose,pseudoprojective=pseudoprojective)
|
||||
if out_loc:
|
||||
write_parses(lang, dev_loc, model_dir, out_loc)
|
||||
print(model_dir)
|
||||
scorer = evaluate(lang, list(read_json_file(dev_loc)),
|
||||
scorer = evaluate(lang, list(read_json_file(dev_loc, verbose=True)),
|
||||
model_dir, gold_preproc=gold_preproc, verbose=verbose)
|
||||
print('TOK', scorer.token_acc)
|
||||
print('POS', scorer.tags_acc)
|
||||
|
|
Loading…
Reference in New Issue
Block a user