diff --git a/bin/parser/train.py b/bin/parser/train.py index 89d9b1f4c..a89316ef1 100755 --- a/bin/parser/train.py +++ b/bin/parser/train.py @@ -21,6 +21,7 @@ from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir from spacy.syntax.parser import GreedyParser from spacy.syntax.parser import OracleError from spacy.syntax.util import Config +from spacy.syntax.conll import GoldParse def is_punct_label(label): @@ -184,6 +185,7 @@ def get_labels(sents): def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0, gold_preproc=False, force_gold=False): + print "Setup model dir" dep_model_dir = path.join(model_dir, 'deps') pos_model_dir = path.join(model_dir, 'pos') if path.exists(dep_model_dir): @@ -198,7 +200,6 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0, labels = Language.ParserTransitionSystem.get_labels(gold_sents) Config.write(dep_model_dir, 'config', features=feat_set, seed=seed, labels=labels) - nlp = Language() for itn in range(n_iter): @@ -206,16 +207,16 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0, pos_corr = 0 n_tokens = 0 for gold_sent in gold_sents: - tokens = nlp.tokenizer(gold_sent.raw) - gold_sent.align_to_tokens(tokens) + tokens = nlp.tokenizer(gold_sent.raw_text) + gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids) nlp.tagger(tokens) heads_corr += nlp.parser.train(tokens, gold_sent, force_gold=force_gold) - pos_corr += nlp.tagger.train(tokens, gold_parse.tags) + pos_corr += nlp.tagger.train(tokens, gold_sent.tags) n_tokens += len(tokens) acc = float(heads_corr) / n_tokens pos_acc = float(pos_corr) / n_tokens print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc - random.shuffle(paragraphs) + random.shuffle(gold_sents) nlp.parser.model.end_training() nlp.tagger.model.end_training() return acc @@ -257,10 +258,16 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False): return float(uas_corr) / (total + loss) -def main(train_loc, dev_loc, model_dir): - train(English, read_docparse_gold(train_loc), model_dir, +@plac.annotations( + train_loc=("Training file location",), + dev_loc=("Dev. file location",), + model_dir=("Location of output model directory",), + n_sents=("Number of training sentences", "option", "n", int) +) +def main(train_loc, dev_loc, model_dir, n_sents=0): + train(English, read_gold(train_loc, n=n_sents), model_dir, gold_preproc=False, force_gold=False) - print evaluate(English, read_docparse_gold(dev_loc), model_dir, gold_preproc=False) + print evaluate(English, read_gold(dev_loc), model_dir, gold_preproc=False) if __name__ == '__main__':