From e167355505cbcd1aba8b9a05513ff9ccb8f26f72 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Wed, 6 May 2015 16:38:54 +0200
Subject: [PATCH] * Use JSON docs for training and evaluation. Currently a bug that is costing 0.6 acc

---
Reviewer notes (text in this section is ignored when the patch is applied):

* This patch was recovered from a copy whose newlines and indentation had
  been flattened. Line breaks were re-derived from the hunk headers; the
  indentation of context lines is inferred from Python syntax and should be
  verified against the pre-image (blob 9ae3a3267), or the patch applied
  with `git apply --ignore-whitespace`.
* Changes made relative to the recovered text, all in the last hunk:
  - get_sents() now raises ValueError for a section other than 'train' or
    'dev'; previously file_range was left unbound and the generator failed
    with UnboundLocalError on first use.
  - get_sents() builds the file name with '%02d' instead of manual padding,
    and gains a docstring.
  - A FIXME marks the write_parses() call in main(), which still passes
    dev_loc although main() no longer receives it.
  - The recovered text showed three added blank lines after get_sents(),
    one more than its hunk header allowed; two are kept.
  - The last hunk header and the diffstat were recomputed accordingly. The
    post-image hash on the index line therefore no longer matches.
* Left unchanged but worth confirming before merge:
  - random.shuffle(gold_tuples) is commented out, so every iteration sees
    the training data in the same order.
  - The fourth element of each gold tuple is named `ctnt` in train() and
    `brackets` in evaluate().

 bin/parser/train.py | 53 +++++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 39 insertions(+), 14 deletions(-)

diff --git a/bin/parser/train.py b/bin/parser/train.py
index 9ae3a3267..922e245ea 100755
--- a/bin/parser/train.py
+++ b/bin/parser/train.py
@@ -19,13 +19,13 @@ from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
 from spacy.syntax.parser import GreedyParser
 from spacy.syntax.parser import OracleError
 from spacy.syntax.util import Config
-from spacy.syntax.conll import read_docparse_file
+from spacy.syntax.conll import read_docparse_file, read_json_file
 from spacy.syntax.conll import GoldParse
 
 from spacy.scorer import Scorer
 
 
-def train(Language, train_loc, model_dir, n_iter=15, feat_set=u'basic', seed=0,
+def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0,
           gold_preproc=False, n_sents=0):
     dep_model_dir = path.join(model_dir, 'deps')
     pos_model_dir = path.join(model_dir, 'pos')
@@ -42,8 +42,6 @@ def train(Language, train_loc, model_dir, n_iter=15, feat_set=u'basic', seed=0,
     setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES,
                     pos_model_dir)
 
-    gold_tuples = read_docparse_file(train_loc)
-
     Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
                  labels=Language.ParserTransitionSystem.get_labels(gold_tuples))
     Config.write(ner_model_dir, 'config', features='ner', seed=seed,
@@ -56,9 +54,12 @@ def train(Language, train_loc, model_dir, n_iter=15, feat_set=u'basic', seed=0,
     print "Itn.\tUAS\tNER F.\tTag %"
     for itn in range(n_iter):
         scorer = Scorer()
-        for raw_text, segmented_text, annot_tuples in gold_tuples:
+        for raw_text, segmented_text, annot_tuples, ctnt in gold_tuples:
             # Eval before train
             tokens = nlp(raw_text, merge_mwes=False)
+            #print segmented_text
+            #for annot in zip(*annot_tuples):
+            #    print annot
             gold = GoldParse(tokens, annot_tuples)
             scorer.score(tokens, gold, verbose=False)
 
@@ -75,19 +76,18 @@ def train(Language, train_loc, model_dir, n_iter=15, feat_set=u'basic', seed=0,
                 nlp.tagger.train(tokens, gold.tags)
 
         print '%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.ents_f, scorer.tags_acc)
-        random.shuffle(gold_tuples)
+        #random.shuffle(gold_tuples)
     nlp.parser.model.end_training()
     nlp.entity.model.end_training()
     nlp.tagger.model.end_training()
     nlp.vocab.strings.dump(path.join(model_dir, 'vocab', 'strings.txt'))
 
 
-def evaluate(Language, dev_loc, model_dir, gold_preproc=False, verbose=True):
+def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=True):
     assert not gold_preproc
     nlp = Language(data_dir=model_dir)
-    gold_tuples = read_docparse_file(dev_loc)
     scorer = Scorer()
-    for raw_text, segmented_text, annot_tuples in gold_tuples:
+    for raw_text, segmented_text, annot_tuples, brackets in gold_tuples:
         tokens = nlp(raw_text, merge_mwes=False)
         gold = GoldParse(tokens, annot_tuples)
         scorer.score(tokens, gold, verbose=verbose)
@@ -108,22 +108,47 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
     return scorer
 
 
+def get_sents(json_dir, section):
+    """Yield each sentence from the JSON files of one corpus section.
+
+    'train' reads 02.json .. 21.json and 'dev' reads 22.json, all looked up
+    directly under json_dir.  Any other section name raises ValueError when
+    the generator is first advanced.
+    """
+    if section == 'train':
+        file_range = range(2, 22)
+    elif section == 'dev':
+        file_range = range(22, 23)
+    else:
+        raise ValueError(
+            "Unknown section %r: expected 'train' or 'dev'" % section)
+
+    for i in file_range:
+        # File names are the zero-padded, two-digit file number.
+        loc = path.join(json_dir, '%02d.json' % i)
+        for sent in read_json_file(loc):
+            yield sent
+
+
 @plac.annotations(
-    train_loc=("Training file location",),
-    dev_loc=("Dev. file location",),
+    json_dir=("Annotated JSON files directory",),
     model_dir=("Location of output model directory",),
     out_loc=("Out location", "option", "o", str),
     n_sents=("Number of training sentences", "option", "n", int),
     verbose=("Verbose error reporting", "flag", "v", bool),
     debug=("Debug mode", "flag", "d", bool)
 )
-def main(train_loc, dev_loc, model_dir, n_sents=0, out_loc="", verbose=False,
+def main(json_dir, model_dir, n_sents=0, out_loc="", verbose=False,
          debug=False):
-    train(English, train_loc, model_dir, feat_set='basic' if not debug else 'debug',
+    train(English, list(get_sents(json_dir, 'train')), model_dir,
+          feat_set='basic' if not debug else 'debug',
           gold_preproc=False, n_sents=n_sents)
     if out_loc:
+        # FIXME: main() no longer receives dev_loc, so unless a module-level
+        # dev_loc exists this call raises NameError whenever out_loc is set.
         write_parses(English, dev_loc, model_dir, out_loc)
-    scorer = evaluate(English, dev_loc, model_dir, gold_preproc=False, verbose=verbose)
+    scorer = evaluate(English, list(get_sents(json_dir, 'dev')),
+                      model_dir, gold_preproc=False, verbose=verbose)
     print 'TOK', scorer.mistokened
     print 'POS', scorer.tags_acc
     print 'UAS', scorer.uas