diff --git a/bin/prepare_treebank.py b/bin/prepare_treebank.py
index 1de2dfdee..0d0e48921 100644
--- a/bin/prepare_treebank.py
+++ b/bin/prepare_treebank.py
@@ -60,15 +60,12 @@ def format_doc(section, filename, raw_paras, ptb_loc, dep_loc):
                 'brackets': []}
         for raw_sent in raw_sents:
             para['sents'].append(offset)
-            _, brackets = read_ptb.parse(ptb_sents[i])
-            _, annot = read_conll.parse(dep_sents[i])
+            _, brackets = read_ptb.parse(ptb_sents[i], strip_bad_periods=True)
+            _, annot = read_conll.parse(dep_sents[i], strip_bad_periods=True)
             indices, word_idx, offset = _get_word_indices(raw_sent, 0, offset)
             for token in annot:
-                if token['head'] == -1:
-                    head = indices[token['id']]
-                else:
-                    head = indices[token['head']]
+                head = indices[token['head']]
                 try:
                     para['tokens'].append({'start': indices[token['id']],
                                            'tag': token['tag'],
@@ -80,32 +77,34 @@ def format_doc(section, filename, raw_paras, ptb_loc, dep_loc):
                 print raw_sent
                 raise
         for label, start, end in brackets:
-            para['brackets'].append({'label': label,
-                                     'start': indices[start],
-                                     'end': indices[end-1]})
+            if start != end:
+                para['brackets'].append({'label': label,
+                                         'start': indices[start],
+                                         'end': indices[end-1]})
             i += 1
         doc['paragraphs'].append(para)
     return doc
 
 
-def main(onto_dir, raw_dir, out_loc):
-    docs = []
+def main(onto_dir, raw_dir, out_dir):
     for i in range(25):
         section = str(i) if i >= 10 else ('0' + str(i))
         raw_loc = path.join(raw_dir, 'wsj%s.json' % section)
+        docs = []
         for j, raw_paras in enumerate(_iter_raw_files(raw_loc)):
             if section == '00':
                 j += 1
             filename = str(j) if j >= 9 else ('0' + str(j))
             if section == '04' and filename == '55':
                 continue
-            ptb_loc = path.join(onto_dir, section, 'wsj_%s%s.parse' % (section, filename))
-            dep_loc = ptb_loc + '.dep'
+            ptb_loc = path.join(onto_dir, section, 'wsj_%s%s.mrg' % (section, filename))
+            dep_loc = ptb_loc + '.3.pa.gs.tab'
             if path.exists(ptb_loc) and path.exists(dep_loc):
                 print ptb_loc
                 doc = format_doc(section, filename, raw_paras, ptb_loc, dep_loc)
                 docs.append(doc)
-    json.dump(docs, open(out_loc, 'w'))
+        with open(path.join(out_dir, '%s.json' % section), 'w') as file_:
+            json.dump(docs, file_)
 
 
 if __name__ == '__main__':
diff --git a/spacy/munge/__init__.py b/spacy/munge/__init__.py
new file mode 100644
index 000000000..e69de29bb
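
Note on the format_doc() change above: the `head == -1` special case disappears
because the new read_conll.parse() (added below) converts CoNLL's 1-indexed
heads to 0-indexed token positions and encodes the root as self-attachment, so
every head can be looked up in `indices` directly. A minimal sketch of that
convention; the three-token sentence and its labels are made up, but the call
runs against the parser as written in this patch:

    from spacy.munge import read_conll

    sample = "Pierre\tNNP\t2\tnsubj\nslept\tVBD\t0\troot\n.\t.\t2\tpunct"
    words, annot = read_conll.parse(sample)
    assert words == ['Pierre', 'slept', '.']
    assert annot[0]['head'] == 1  # 1-indexed head 2 -> 0-indexed position 1
    assert annot[1]['head'] == 1  # root (CoNLL head 0) points at itself
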
diff --git a/spacy/munge/align_raw.py b/spacy/munge/align_raw.py
new file mode 100644
index 000000000..5d3954b11
--- /dev/null
+++ b/spacy/munge/align_raw.py
@@ -0,0 +1,175 @@
+"""Align the raw sentences from Read et al (2012) to the PTB tokenization,
+outputting the format:
+
+[{
+    section: int,
+    file: string,
+    paragraphs: [{
+        raw: string,
+        segmented: string,
+        tokens: [int]}]}]
+"""
+import plac
+from pathlib import Path
+import json
+from os import path
+
+from spacy.munge import read_ptb
+
+
+def read_unsegmented(section_loc):
+    # Arbitrary patches applied to the _raw_ text to promote alignment.
+    patches = (
+        ('. . . .', '...'),
+        ('....', '...'),
+        ('Co..', 'Co.'),
+        ("`", "'"),
+    )
+
+    paragraphs = []
+    with open(section_loc) as file_:
+        para = []
+        for line in file_:
+            if line.startswith('['):
+                line = line.split('|', 1)[1].strip()
+                for find, replace in patches:
+                    line = line.replace(find, replace)
+                para.append(line)
+            else:
+                paragraphs.append(para)
+                para = []
+        paragraphs.append(para)
+    return paragraphs
+
+
+def read_ptb_sec(ptb_sec_dir):
+    ptb_sec_dir = Path(ptb_sec_dir)
+    files = []
+    for loc in ptb_sec_dir.iterdir():
+        if not str(loc).endswith('parse') and not str(loc).endswith('mrg'):
+            continue
+        with loc.open() as file_:
+            text = file_.read()
+        sents = []
+        for parse_str in read_ptb.split(text):
+            words, brackets = read_ptb.parse(parse_str, strip_bad_periods=True)
+            words = [_reform_ptb_word(word) for word in words]
+            string = ' '.join(words)
+            sents.append(string)
+        files.append(sents)
+    return files
+
+
+def _reform_ptb_word(tok):
+    tok = tok.replace("``", '"')
+    tok = tok.replace("`", "'")
+    tok = tok.replace("''", '"')
+    tok = tok.replace('\\', '')
+    tok = tok.replace('-LCB-', '{')
+    tok = tok.replace('-RCB-', '}')
+    tok = tok.replace('-RRB-', ')')
+    tok = tok.replace('-LRB-', '(')
+    tok = tok.replace("'T-", "'T")
+    return tok
+
+
+def get_alignment(raw_by_para, ptb_by_file):
+    # These are lists-of-lists, by paragraph and file respectively.
+    # Flatten them into a list of (outer_id, inner_id, item) triples.
+    raw_sents = _flatten(raw_by_para)
+    ptb_sents = _flatten(ptb_by_file)
+
+    assert len(raw_sents) == len(ptb_sents)
+
+    output = []
+    for (p_id, p_sent_id, raw), (f_id, f_sent_id, ptb) in zip(raw_sents, ptb_sents):
+        alignment = align_chars(raw, ptb)
+        sepped = []
+        for i, c in enumerate(ptb):
+            if alignment[i] is False:
+                sepped.append('<SEP>')
+            else:
+                sepped.append(c)
+        output.append((f_id, p_id, f_sent_id, ''.join(sepped)))
+    return output
+
+
+def _flatten(nested):
+    flat = []
+    for id1, inner in enumerate(nested):
+        flat.extend((id1, id2, item) for id2, item in enumerate(inner))
+    return flat
+
+
+def align_chars(raw, ptb):
+    i = 0
+    j = 0
+
+    length = len(raw)
+    alignment = [False for _ in range(len(ptb))]
+    while i < length:
+        if raw[i] == ' ' and ptb[j] == ' ':
+            alignment[j] = True
+            i += 1
+            j += 1
+        elif raw[i] == ' ':
+            i += 1
+        elif ptb[j] == ' ':
+            j += 1
+        else:
+            assert raw[i].lower() == ptb[j].lower(), raw[i:]
+            alignment[j] = i
+            i += 1; j += 1
+    return alignment
+
+
+def group_into_files(sents):
+    last_id = 0
+    this = []
+    output = []
+    for f_id, p_id, s_id, sent in sents:
+        if f_id != last_id:
+            output.append(this)
+            this = []
+        this.append((f_id, p_id, s_id, sent))
+        last_id = f_id
+    if this:
+        output.append(this)
+    return output
+
+
+def group_into_paras(sents):
+    last_id = 0
+    this = []
+    output = []
+    for f_id, p_id, s_id, sent in sents:
+        if p_id != last_id and this:
+            output.append(this)
+            this = []
+        this.append(sent)
+        last_id = p_id
+    if this:
+        output.append(this)
+    return output
+
+
+def get_sections(odc_dir, ptb_dir, out_dir):
+    for i in range(25):
+        section = str(i) if i >= 10 else ('0' + str(i))
+        odc_loc = path.join(odc_dir, 'wsj%s.txt' % section)
+        ptb_sec = path.join(ptb_dir, section)
+        out_loc = path.join(out_dir, 'wsj%s.json' % section)
+        yield odc_loc, ptb_sec, out_loc
+
+
+def main(odc_dir, ptb_dir, out_dir):
+    for odc_loc, ptb_sec_dir, out_loc in get_sections(odc_dir, ptb_dir, out_dir):
+        raw_paragraphs = read_unsegmented(odc_loc)
+        ptb_files = read_ptb_sec(ptb_sec_dir)
+        aligned = get_alignment(raw_paragraphs, ptb_files)
+        files = [group_into_paras(f) for f in group_into_files(aligned)]
+        with open(out_loc, 'w') as file_:
+            json.dump(files, file_)
+
+
+if __name__ == '__main__':
+    plac.call(main)
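
Note on align_chars()/get_alignment(): the alignment is a per-character map
from the PTB-tokenized string back into the raw string, with True marking a
space present in both strings and False marking a space only the PTB
tokenization inserted; get_alignment() then rewrites those False positions as
'<SEP>' to build the "segmented" strings. A small illustration, with an
invented sentence pair:

    from spacy.munge.align_raw import align_chars

    raw = "can't go"
    ptb = "ca n't go"
    # Matched characters map to raw offsets; True is a shared space;
    # False is the token-boundary space PTB added inside "can't".
    assert align_chars(raw, ptb) == [0, 1, False, 2, 3, 4, True, 6, 7]

Fed through get_alignment(), this pair would come out as the segmented
string "ca<SEP>n't go".
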
diff --git a/spacy/munge/read_conll.py b/spacy/munge/read_conll.py
new file mode 100644
index 000000000..6b563c1b7
--- /dev/null
+++ b/spacy/munge/read_conll.py
@@ -0,0 +1,40 @@
+from __future__ import unicode_literals
+
+
+def split(text):
+    return [sent.strip() for sent in text.split('\n\n') if sent.strip()]
+
+
+def parse(sent_text, strip_bad_periods=False):
+    sent_text = sent_text.strip()
+    assert sent_text
+    annot = []
+    words = []
+    i = 0
+    for line in sent_text.split('\n'):
+        word, tag, head, dep = line.split()
+        if strip_bad_periods and words and _is_bad_period(words[-1], word):
+            continue
+
+        annot.append({
+            'id': i,
+            'word': word,
+            'tag': tag,
+            'head': int(head) - 1 if int(head) != 0 else i,
+            'dep': dep})
+        words.append(word)
+        i += 1
+    return words, annot
+
+
+def _is_bad_period(prev, period):
+    if period != '.':
+        return False
+    elif prev == '.':
+        return False
+    elif not prev.endswith('.'):
+        return False
+    else:
+        return True
diff --git a/spacy/munge/read_ptb.py b/spacy/munge/read_ptb.py
new file mode 100644
index 000000000..609397ba0
--- /dev/null
+++ b/spacy/munge/read_ptb.py
@@ -0,0 +1,65 @@
+import re
+import os
+from os import path
+
+
+def parse(sent_text, strip_bad_periods=False):
+    sent_text = sent_text.strip()
+    assert sent_text and sent_text.startswith('(')
+    open_brackets = []
+    brackets = []
+    bracketsRE = re.compile(r'(\()([^\s\)\(]+)|([^\s\)\(]+)?(\))')
+    word_i = 0
+    words = []
+    # Remove outermost bracket
+    if sent_text.startswith('(('):
+        sent_text = sent_text.replace('((', '( (', 1)
+    for match in bracketsRE.finditer(sent_text[2:-1]):
+        open_, label, text, close = match.groups()
+        if open_:
+            assert not close
+            assert label.strip()
+            open_brackets.append((label, word_i))
+        else:
+            assert close
+            label, start = open_brackets.pop()
+            assert label.strip()
+            if strip_bad_periods and words and _is_bad_period(words[-1], text):
+                continue
+            # Traces leave 0-width bracket, but no token
+            if text and label != '-NONE-':
+                words.append(text)
+                word_i += 1
+            else:
+                brackets.append((label, start, word_i))
+    return words, brackets
+
+
+def _is_bad_period(prev, period):
+    if period != '.':
+        return False
+    elif prev == '.':
+        return False
+    elif not prev.endswith('.'):
+        return False
+    else:
+        return True
+
+
+def split(text):
+    sentences = []
+    current = []
+
+    for line in text.strip().split('\n'):
+        line = line.rstrip()
+        if not line:
+            continue
+        # Detect the start of sentences by line starting with (
+        # This is messy, but it keeps bracket parsing at the sentence level
+        if line.startswith('(') and current:
+            sentences.append('\n'.join(current))
+            current = []
+        current.append(line)
+    if current:
+        sentences.append('\n'.join(current))
+    return sentences
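
A quick sanity check of the PTB reader, traced by hand against the code in
this patch; the tree is an invented PTB-style parse, not taken from the
corpus:

    from spacy.munge import read_ptb

    text = """
    ( (S
        (NP-SBJ (NNP Mr.) (NNP Vinken) )
        (VP (VBZ is)
          (NP-PRD (NN chairman) ))
        (. .) ))
    """

    sents = read_ptb.split(text)
    assert len(sents) == 1
    words, brackets = read_ptb.parse(sents[0])
    assert words == ['Mr.', 'Vinken', 'is', 'chairman', '.']
    # POS-level brackets are folded into words; phrase brackets come back
    # as (label, start, end) spans over 0-indexed word positions.
    assert ('NP-SBJ', 0, 2) in brackets and ('S', 0, 5) in brackets
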