* Fix parser training script

This commit is contained in:
Matthew Honnibal 2015-02-09 03:57:56 -05:00
parent 5c3513583d
commit ee33be31dd

View File

@@ -52,7 +52,7 @@ def read_tokenized_gold(file_):
def read_docparse_gold(file_):
paragraphs = []
for sent_str in file_.read().strip().split('<text>'):
for sent_str in file_.read().strip().split('\n\n'):
if not sent_str.strip():
continue
words = []
@@ -60,12 +60,8 @@ def read_docparse_gold(file_):
labels = []
tags = []
ids = []
try:
raw_text, sent_str = sent_str.strip().split('</text>', 1)
except:
print sent_str
raise
lines = sent_str.strip().split('\n')
raw_text = lines.pop(0)
tok_text = lines.pop(0)
for i, line in enumerate(lines):
id_, word, pos_string, head_idx, label = _parse_line(line)
@@ -238,7 +234,6 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
skipped = 0
loss = 0
with codecs.open(dev_loc, 'r', 'utf8') as file_:
#paragraphs = read_tokenized_gold(file_)
paragraphs = read_docparse_gold(file_)
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
gold_preproc=gold_preproc):
@@ -246,7 +241,11 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
nlp.tagger(tokens)
nlp.parser(tokens)
for i, token in enumerate(tokens):
pos_corr += token.tag_ == tag_strs[i]
try:
pos_corr += token.tag_ == tag_strs[i]
except:
print i, token.orth_, token.tag
raise
n_tokens += 1
if heads[i] is None:
skipped += 1
@@ -262,10 +261,9 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
def main(train_loc, dev_loc, model_dir):
with codecs.open(train_loc, 'r', 'utf8') as file_:
#train_sents = read_docparse_gold(file_)
train_sents = read_tokenized_gold(file_)
#train(English, train_sents, model_dir, gold_preproc=True, force_gold=False)
print evaluate(English, dev_loc, model_dir, gold_preproc=True)
train_sents = read_docparse_gold(file_)
train(English, train_sents, model_dir, gold_preproc=False, force_gold=False)
print evaluate(English, dev_loc, model_dir, gold_preproc=False)
if __name__ == '__main__':