diff --git a/bin/parser/train.py b/bin/parser/train.py
index 3312a0cf3..7670ca81a 100755
--- a/bin/parser/train.py
+++ b/bin/parser/train.py
@@ -19,6 +19,7 @@
 from spacy.en import English
 from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
 from spacy.syntax.parser import GreedyParser
+from spacy.syntax.parser import OracleError
 from spacy.syntax.util import Config
@@ -36,7 +37,7 @@ def read_tokenized_gold(file_):
         labels = []
         tags = []
         for i, line in enumerate(sent_str.split('\n')):
-            word, pos_string, head_idx, label = _parse_line(line)
+            id_, word, pos_string, head_idx, label = _parse_line(line)
             words.append(word)
             if head_idx == -1:
                 head_idx = i
@@ -44,22 +45,29 @@
             heads.append(head_idx)
             labels.append(label)
             tags.append(pos_string)
-        sents.append((ids_, words, heads, labels, tags))
+        text = ' '.join(words)
+        sents.append((text, [words], ids, words, tags, heads, labels))
     return sents
 
 
 def read_docparse_gold(file_):
     paragraphs = []
-    for sent_str in file_.read().strip().split('\n\n'):
+    for sent_str in file_.read().strip().split(''):
+        if not sent_str.strip():
+            continue
         words = []
         heads = []
         labels = []
         tags = []
         ids = []
+        try:
+            raw_text, sent_str = sent_str.strip().split('', 1)
+        except:
+            print sent_str
+            raise
         lines = sent_str.strip().split('\n')
-        raw_text = lines[0]
-        tok_text = lines[1]
-        for i, line in enumerate(lines[2:]):
+        tok_text = lines.pop(0)
+        for i, line in enumerate(lines):
             id_, word, pos_string, head_idx, label = _parse_line(line)
             if label == 'root':
                 label = 'ROOT'
@@ -180,7 +188,7 @@ def get_labels(sents):
 
 
 def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
-          gold_preproc=False):
+          gold_preproc=False, force_gold=False):
     dep_model_dir = path.join(model_dir, 'deps')
     pos_model_dir = path.join(model_dir, 'pos')
     if path.exists(dep_model_dir):
@@ -205,7 +213,10 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
         for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
                                                          gold_preproc=gold_preproc):
             nlp.tagger(tokens)
-            heads_corr += nlp.parser.train_sent(tokens, heads, labels, force_gold=False)
+            try:
+                heads_corr += nlp.parser.train_sent(tokens, heads, labels, force_gold=force_gold)
+            except OracleError:
+                continue
             pos_corr += nlp.tagger.train(tokens, tag_strs)
             n_tokens += len(tokens)
         acc = float(heads_corr) / n_tokens
@@ -221,10 +232,13 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
     global loss
     nlp = Language()
     n_corr = 0
+    pos_corr = 0
+    n_tokens = 0
     total = 0
     skipped = 0
     loss = 0
     with codecs.open(dev_loc, 'r', 'utf8') as file_:
+        #paragraphs = read_tokenized_gold(file_)
         paragraphs = read_docparse_gold(file_)
     for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
                                                      gold_preproc=gold_preproc):
@@ -232,6 +246,8 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
         nlp.tagger(tokens)
         nlp.parser(tokens)
         for i, token in enumerate(tokens):
+            pos_corr += token.tag_ == tag_strs[i]
+            n_tokens += 1
             if heads[i] is None:
                 skipped += 1
                 continue
@@ -240,14 +256,16 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
             n_corr += token.head.i == heads[i]
             total += 1
     print loss, skipped, (loss+skipped + total)
+    print pos_corr / n_tokens
     return float(n_corr) / (total + loss)
 
 
 def main(train_loc, dev_loc, model_dir):
     with codecs.open(train_loc, 'r', 'utf8') as file_:
-        train_sents = read_docparse_gold(file_)
-    train(English, train_sents, model_dir, gold_preproc=False)
-    print evaluate(English, dev_loc, model_dir, gold_preproc=False)
+        #train_sents = read_docparse_gold(file_)
+        train_sents = read_tokenized_gold(file_)
+    #train(English, train_sents, model_dir, gold_preproc=True, force_gold=False)
+    print evaluate(English, dev_loc, model_dir, gold_preproc=True)
 
 
 if __name__ == '__main__':
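Note on the training-loop change above: nlp.parser.train_sent can raise OracleError when the transition oracle cannot reach the gold parse (for example, a non-projective tree under a projective transition system), and the patch skips such sentences rather than aborting the epoch. The sketch below restates that pattern in isolation; OracleError and train_sent are the names used in the diff, while train_epoch and the (tokens, heads, labels) tuples are hypothetical stand-ins, not spaCy API.

    # Minimal sketch of the skip-on-OracleError pattern from the diff
    # (Python 2 style, matching the script above). OracleError and
    # train_sent are the names used in the diff; train_epoch and the
    # sentence tuples are hypothetical stand-ins for illustration.
    from spacy.syntax.parser import OracleError

    def train_epoch(parser, sents, force_gold=False):
        heads_corr = 0
        n_tokens = 0
        n_skipped = 0
        for tokens, heads, labels in sents:
            try:
                # The diff accumulates train_sent's return value as the
                # count of correctly attached heads.
                heads_corr += parser.train_sent(tokens, heads, labels,
                                                force_gold=force_gold)
            except OracleError:
                # The oracle cannot reach the gold parse; skip the sentence.
                n_skipped += 1
                continue
            n_tokens += len(tokens)
        # float() guards against Python 2 integer division, as in the
        # diff's own acc = float(heads_corr) / n_tokens line.
        return float(heads_corr) / n_tokens, n_skipped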