From 8adeea37462cf7f0f62ca6aa842ba344248dc0cb Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Sat, 24 Feb 2018 16:04:27 +0100
Subject: [PATCH] Generalize conllu script. Now handling Chinese (maybe badly)

---
 examples/training/conllu.py | 125 ++++++++++++++++++++----------
 1 file changed, 69 insertions(+), 56 deletions(-)

diff --git a/examples/training/conllu.py b/examples/training/conllu.py
index fd2a91222..3d07b2279 100644
--- a/examples/training/conllu.py
+++ b/examples/training/conllu.py
@@ -11,7 +11,7 @@ import spacy.util
 from spacy.tokens import Doc
 from spacy.gold import GoldParse, minibatch
 from spacy.syntax.nonproj import projectivize
-from collections import Counter
+from collections import defaultdict, Counter
 from timeit import default_timer as timer
 from spacy.matcher import Matcher
 
@@ -56,7 +56,7 @@ def split_text(text):
 
 
 def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
-              limit=None):
+              max_doc_length=None, limit=None):
     '''Read the CONLLU format into (Doc, GoldParse) tuples. If raw_text=True,
     include Doc objects created using nlp.make_doc and then aligned against
     the gold-standard sequences. If oracle_segments=True, include Doc objects
@@ -70,51 +70,67 @@ def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
     docs = []
     golds = []
     for doc_id, (text, cd) in enumerate(zip(paragraphs, conllu)):
-        doc_words = []
-        doc_tags = []
-        doc_heads = []
-        doc_deps = []
-        doc_ents = []
+        sent_annots = []
         for cs in cd:
-            sent_words = []
-            sent_tags = []
-            sent_heads = []
-            sent_deps = []
-            for id_, word, lemma, pos, tag, morph, head, dep, _1, _2 in cs:
+            sent = defaultdict(list)
+            for id_, word, lemma, pos, tag, morph, head, dep, _, space_after in cs:
                 if '.' in id_:
                     continue
                 if '-' in id_:
                     continue
                 id_ = int(id_)-1
                 head = int(head)-1 if head != '0' else id_
-                sent_words.append(word)
-                sent_tags.append(tag)
-                sent_heads.append(head)
-                sent_deps.append('ROOT' if dep == 'root' else dep)
+                sent['words'].append(word)
+                sent['tags'].append(tag)
+                sent['heads'].append(head)
+                sent['deps'].append('ROOT' if dep == 'root' else dep)
+                sent['spaces'].append(space_after == '_')
+            sent['entities'] = ['-'] * len(sent['words'])
+            sent['heads'], sent['deps'] = projectivize(sent['heads'],
+                                                       sent['deps'])
             if oracle_segments:
-                sent_heads, sent_deps = projectivize(sent_heads, sent_deps)
-                docs.append(Doc(nlp.vocab, words=sent_words))
-                golds.append(GoldParse(docs[-1], words=sent_words, heads=sent_heads,
-                                       tags=sent_tags, deps=sent_deps,
-                                       entities=['-']*len(sent_words)))
-            for head in sent_heads:
-                doc_heads.append(len(doc_words)+head)
-            doc_words.extend(sent_words)
-            doc_tags.extend(sent_tags)
-            doc_deps.extend(sent_deps)
-            doc_ents.extend(['-']*len(sent_words))
-        # Create a GoldParse object for the sentence
-        doc_heads, doc_deps = projectivize(doc_heads, doc_deps)
-        if raw_text:
-            docs.append(nlp.make_doc(text))
-            golds.append(GoldParse(docs[-1], words=doc_words, tags=doc_tags,
-                                   heads=doc_heads, deps=doc_deps,
-                                   entities=doc_ents))
-        if limit and doc_id >= limit:
-            break
+                docs.append(Doc(nlp.vocab, words=sent['words'], spaces=sent['spaces']))
+                golds.append(GoldParse(docs[-1], **sent))
+
+            sent_annots.append(sent)
+            if raw_text and max_doc_length and len(sent_annots) >= max_doc_length:
+                doc, gold = _make_gold(nlp, None, sent_annots)
+                sent_annots = []
+                docs.append(doc)
+                golds.append(gold)
+                if limit and len(docs) >= limit:
+                    return docs, golds
+
+        if raw_text and sent_annots:
+            doc, gold = _make_gold(nlp, None, sent_annots)
+            docs.append(doc)
+            golds.append(gold)
+        if limit and len(docs) >= limit:
+            return docs, golds
     return docs, golds
 
 
+def _make_gold(nlp, text, sent_annots):
+    # Flatten the conll annotations, and adjust the head indices
+    flat = defaultdict(list)
+    for sent in sent_annots:
+        flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
+        for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
+            flat[field].extend(sent[field])
+    # Construct text if necessary
+    assert len(flat['words']) == len(flat['spaces'])
+    if text is None:
+        text = ''.join(word+' '*space for word, space in zip(flat['words'], flat['spaces']))
+    doc = nlp.make_doc(text)
+    flat.pop('spaces')
+    gold = GoldParse(doc, **flat)
+    #for annot in gold.orig_annot:
+    #    print(annot)
+    #for i in range(len(doc)):
+    #    print(doc[i].text, gold.words[i], gold.labels[i], gold.heads[i])
+    return doc, gold
+
+
 def refresh_docs(docs):
     vocab = docs[0].vocab
     return [Doc(vocab, words=[t.text for t in doc],
@@ -124,8 +140,8 @@ def refresh_docs(docs):
 
 def read_conllu(file_):
     docs = []
-    doc = None
     sent = []
+    doc = []
     for line in file_:
         if line.startswith('# newdoc'):
             if doc:
@@ -135,29 +151,23 @@ def read_conllu(file_):
             continue
         elif not line.strip():
             if sent:
-                if doc is None:
-                    docs.append([sent])
-                else:
-                    doc.append(sent)
+                doc.append(sent)
             sent = []
         else:
             sent.append(line.strip().split())
     if sent:
-        if doc is None:
-            docs.append([sent])
-        else:
-            doc.append(sent)
+        doc.append(sent)
     if doc:
         docs.append(doc)
     return docs
 
 
 def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
-                   joint_sbd=True):
+                   joint_sbd=True, limit=None):
     with open(text_loc) as text_file:
         with open(conllu_loc) as conllu_file:
             docs, golds = read_data(nlp, conllu_file, text_file,
-                                    oracle_segments=oracle_segments)
+                                    oracle_segments=oracle_segments, limit=limit)
     if joint_sbd:
         pass
     else:
@@ -200,10 +210,11 @@ def print_conllu(docs, file_):
     merger.add('SUBTOK', None, [{'DEP': 'subtok', 'op': '+'}])
     for i, doc in enumerate(docs):
         matches = merger(doc)
-        spans = [(doc[start].idx, doc[end+1].idx+len(doc[end+1]))
-                 for (_, start, end) in matches if end < (len(doc)-1)]
-        for start_char, end_char in spans:
+        spans = [doc[start:end+1] for _, start, end in matches]
+        offsets = [(span.start_char, span.end_char) for span in spans]
+        for start_char, end_char in offsets:
             doc.merge(start_char, end_char)
+        #print([t.text for t in doc])
         file_.write("# newdoc id = {i}\n".format(i=i))
         for j, sent in enumerate(doc.sents):
             file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
@@ -232,7 +243,7 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
         with open(text_train_loc) as text_file:
             docs, golds = read_data(nlp, conllu_file, text_file,
                                     oracle_segments=False, raw_text=True,
-                                    limit=None)
+                                    max_doc_length=10, limit=None)
     print("Create parser")
     nlp.add_pipe(nlp.create_pipe('parser'))
     nlp.parser.add_multitask_objective('tag')
@@ -257,7 +268,7 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
     # Batch size starts at 1 and grows, so that we make updates quickly
     # at the beginning of training.
     batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 1),
-                                         spacy.util.env_opt('batch_to', 2),
+                                         spacy.util.env_opt('batch_to', 8),
                                          spacy.util.env_opt('batch_compound', 1.001))
     for i in range(30):
         docs = refresh_docs(docs)
@@ -275,13 +286,15 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
         with nlp.use_params(optimizer.averages):
             dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
-                                              oracle_segments=False, joint_sbd=True)
+                                              oracle_segments=False, joint_sbd=True,
+                                              limit=5)
             print_progress(i, losses, scorer)
             with open(output_loc, 'w') as file_:
                 print_conllu(dev_docs, file_)
-            dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
-                                              oracle_segments=False, joint_sbd=False)
-            print_progress(i, losses, scorer)
+            with open('/tmp/train.conllu', 'w') as file_:
+                print_conllu(list(nlp.pipe([d.text for d in batch_docs])), file_)
+
+
 
 if __name__ == '__main__':
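
A minimal standalone sketch of the bookkeeping the new _make_gold() relies on: per-sentence
annotation dicts are flattened into one document-level annotation, each sentence's head indices
are offset by the number of words already collected, and the raw text is rebuilt from the words
plus the CoNLL-U SpaceAfter information. The flatten_sents helper and the two example sentences
below are hypothetical illustrations, not part of the patch:

    from collections import defaultdict

    def flatten_sents(sent_annots):
        # Mirrors the patch's _make_gold() bookkeeping, without spaCy objects:
        # heads are sentence-internal indices, so each sentence's heads are
        # shifted by the number of words seen so far before being appended.
        flat = defaultdict(list)
        for sent in sent_annots:
            offset = len(flat['words'])
            flat['heads'].extend(offset + head for head in sent['heads'])
            for field in ['words', 'tags', 'deps', 'spaces']:
                flat[field].extend(sent[field])
        # Rebuild the text the same way the patch does: a word is followed by
        # a space unless the CoNLL-U MISC column said SpaceAfter=No.
        text = ''.join(word + ' ' * space
                       for word, space in zip(flat['words'], flat['spaces']))
        return dict(flat), text

    # Hypothetical two-sentence input (0-based head indices, as in the patch):
    sents = [
        {'words': ['I', 'slept'], 'tags': ['PRP', 'VBD'],
         'heads': [1, 1], 'deps': ['nsubj', 'ROOT'], 'spaces': [True, True]},
        {'words': ['He', 'left'], 'tags': ['PRP', 'VBD'],
         'heads': [1, 1], 'deps': ['nsubj', 'ROOT'], 'spaces': [True, False]},
    ]
    flat, text = flatten_sents(sents)
    print(flat['heads'])   # [1, 1, 3, 3] -- second sentence shifted by 2 words
    print(repr(text))      # 'I slept He left'

In the patch itself the flattened dict is passed to GoldParse() against a Doc built with
nlp.make_doc(text), which is what lets the raw-text training path group several CoNLL-U
sentences into one training document without gold sentence boundaries.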