diff --git a/bin/parser/train.py b/bin/parser/train.py index a89316ef1..998c74819 100755 --- a/bin/parser/train.py +++ b/bin/parser/train.py @@ -21,7 +21,7 @@ from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir from spacy.syntax.parser import GreedyParser from spacy.syntax.parser import OracleError from spacy.syntax.util import Config -from spacy.syntax.conll import GoldParse +from spacy.syntax.conll import GoldParse, is_punct_label def is_punct_label(label): @@ -206,15 +206,22 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0, heads_corr = 0 pos_corr = 0 n_tokens = 0 + n_all_tokens = 0 for gold_sent in gold_sents: - tokens = nlp.tokenizer(gold_sent.raw_text) - gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids) + if gold_preproc: + #print ' '.join(gold_sent.words) + tokens = nlp.tokenizer.tokens_from_list(gold_sent.words) + gold_sent.map_heads(nlp.parser.moves.label_ids) + else: + tokens = nlp.tokenizer(gold_sent.raw_text) + gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids) nlp.tagger(tokens) heads_corr += nlp.parser.train(tokens, gold_sent, force_gold=force_gold) pos_corr += nlp.tagger.train(tokens, gold_sent.tags) - n_tokens += len(tokens) + n_tokens += gold_sent.n_non_punct + n_all_tokens += len(tokens) acc = float(heads_corr) / n_tokens - pos_acc = float(pos_corr) / n_tokens + pos_acc = float(pos_corr) / n_all_tokens print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc random.shuffle(gold_sents) nlp.parser.model.end_training() @@ -241,21 +248,26 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False): nlp.tagger(tokens) nlp.parser(tokens) for i, token in enumerate(tokens): - pos_corr += token.tag_ == tag_strs[i] + pos_corr += token.tag_ == gold_sent.tags[i] n_tokens += 1 - if heads[i] is None: + if gold_sent.heads[i] is None: skipped += 1 continue - if is_punct_label(labels[i]): - continue - uas_corr += token.head.i == heads[i] - las_corr += token.head.i == heads[i] and token.dep_ == labels[i] - #print token.orth_, token.head.orth_, token.dep_, labels[i] - total += 1 + #print i, token.orth_, token.head.i, gold_sent.py_heads[i], gold_sent.labels[i], + #print gold_sent.is_correct(i, token.head.i) + if gold_sent.labels[i] != 'P': + n_corr += gold_sent.is_correct(i, token.head.i) + total += 1 print loss, skipped, (loss+skipped + total) print pos_corr / n_tokens - print float(las_corr) / (total + loss) - return float(uas_corr) / (total + loss) + return float(n_corr) / (total + loss) + + +def read_gold(loc, n=0): + sent_strs = open(loc).read().strip().split('\n\n') + if n == 0: + n = len(sent_strs) + return [GoldParse.from_docparse(sent) for sent in sent_strs[:n]] @plac.annotations( @@ -265,8 +277,8 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False): n_sents=("Number of training sentences", "option", "n", int) ) def main(train_loc, dev_loc, model_dir, n_sents=0): - train(English, read_gold(train_loc, n=n_sents), model_dir, - gold_preproc=False, force_gold=False) + #train(English, read_gold(train_loc, n=n_sents), model_dir, + # gold_preproc=False, force_gold=False) print evaluate(English, read_gold(dev_loc), model_dir, gold_preproc=False)