diff --git a/bin/prepare_treebank.py b/bin/prepare_treebank.py index 34c2de3e6..533f7a0c6 100644 --- a/bin/prepare_treebank.py +++ b/bin/prepare_treebank.py @@ -122,53 +122,10 @@ def get_file_names(section_dir, subsection): return list(sorted(set(filenames))) -def main(onto_dir, raw_dir, out_loc): - # All but WSJ --- we do that separately, as we have the source docs - sections = [ - 'bc/cctv', - 'bc/cnn', - 'bc/msnbc', - 'bc/p2.5_a2e', - 'bc/p2.5_c2e', - 'bc/phoenix', - 'bn/abc', - 'bn/cnn', - 'bn/mnb', - 'bn/nbc', - 'bn/p2.5_a2e', - 'bn/p2.5_c2e', - 'bn/pri', - 'bn/voa', - 'mz/sinorama', - 'nw/dev_09_c2e', - 'nw/p2.5_a2e', - 'nw/p2.5_c2e', - 'nw/xinhua', - 'pt/ot', - 'tc/ch', - 'wb/a2e', - 'wb/c2e', - 'wb/eng', - 'wb/dev_09_c2e', - 'wb/p2.5_a2e', - 'wb/p2.5_c2e', - 'wb/sel' - ] - docs = [] - for section in sections: - section_dir = path.join(onto_dir, 'data', 'english', 'annotations', section) - print section, len(docs) - for subsection in os.listdir(section_dir): - for fn in get_file_names(section_dir, subsection): - ptb = read_file(section_dir, subsection, '%s.parse' % fn) - dep = read_file(section_dir, subsection, '%s.parse.dep' % fn) - ner = read_file(section_dir, subsection, '%s.name' % fn) - if ptb is not None: - doc = format_doc(fn, None, ptb, dep, ner) - if doc is not None: - docs.append(doc) +def read_wsj_with_source(onto_dir, raw_dir): # Now do WSJ, with source alignment onto_dir = path.join(onto_dir, 'data', 'english', 'annotations', 'nw', 'wsj') + docs = {} for i in range(25): section = str(i) if i >= 10 else ('0' + str(i)) raw_loc = path.join(raw_dir, 'wsj%s.json' % section) @@ -181,12 +138,40 @@ def main(onto_dir, raw_dir, out_loc): dep = read_file(onto_dir, section, '%s.parse.dep' % filename) ner = read_file(onto_dir, section, '%s.name' % filename) if ptb is not None and dep is not None: - docs.append(format_doc(filename, raw_paras, ptb, dep, ner)) - print 'nw/wsj', len(docs) - with open(out_loc, 'w') as file_: - json.dump(docs, file_, indent=4) + docs[filename] = format_doc(filename, raw_paras, ptb, dep, ner) + return docs +def get_doc(onto_dir, file_path, wsj_docs): + filename = file_path.rsplit('/', 1)[1] + if filename in wsj_docs: + return wsj_docs[filename] + else: + ptb = read_file(onto_dir, file_path + '.parse') + dep = read_file(onto_dir, file_path + '.parse.dep') + ner = read_file(onto_dir, file_path + '.name') + if ptb is not None and dep is not None: + return format_doc(filename, None, ptb, dep, ner) + else: + return None + +def read_ids(loc): + return open(loc).read().strip().split('\n') + +def main(onto_dir, raw_dir, out_dir): + wsj_docs = read_wsj_with_source(onto_dir, raw_dir) + + for partition in ('train', 'test', 'development'): + ids = read_ids(path.join(onto_dir, '%s.id' % partition)) + out_loc = path.join(out_dir, '%s.json' % partition) + docs = [] + for file_path in ids: + doc = get_doc(onto_dir, file_path, wsj_docs) + if doc is not None: + docs.append(doc) + with open(out_loc, 'w') as file_: + json.dump(docs, file_, indent=4) + if __name__ == '__main__': plac.call(main)