Generalize conllu script. Now handling Chinese (maybe badly)

Matthew Honnibal 2018-02-24 16:04:27 +01:00
parent 5cc3bd1c1d
commit 8adeea3746


@@ -11,7 +11,7 @@ import spacy.util
 from spacy.tokens import Doc
 from spacy.gold import GoldParse, minibatch
 from spacy.syntax.nonproj import projectivize
-from collections import Counter
+from collections import defaultdict, Counter
 from timeit import default_timer as timer
 from spacy.matcher import Matcher
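
The only new import is defaultdict, which the rewrite below uses to accumulate per-sentence annotation lists without initialising every field by hand. A minimal sketch of the pattern:

    from collections import defaultdict

    sent = defaultdict(list)
    sent['words'].append('word')  # the 'words' list is created on first access
    sent['tags'].append('NN')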
@@ -56,7 +56,7 @@ def split_text(text):
 def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
-              limit=None):
+              max_doc_length=None, limit=None):
     '''Read the CONLLU format into (Doc, GoldParse) tuples. If raw_text=True,
     include Doc objects created using nlp.make_doc and then aligned against
     the gold-standard sequences. If oracle_segments=True, include Doc objects
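
The new max_doc_length argument caps how many sentences are concatenated into one training document before a fresh one is started. A hedged usage sketch (the language code and file names are placeholders; the script itself takes them from the command line):

    import spacy

    nlp = spacy.blank('zh')
    with open('zh-train.conllu') as conllu_file, open('zh-train.txt') as text_file:
        docs, golds = read_data(nlp, conllu_file, text_file,
                                raw_text=True, oracle_segments=False,
                                max_doc_length=10, limit=1000)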
@@ -70,51 +70,67 @@ def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
     docs = []
     golds = []
     for doc_id, (text, cd) in enumerate(zip(paragraphs, conllu)):
-        doc_words = []
-        doc_tags = []
-        doc_heads = []
-        doc_deps = []
-        doc_ents = []
+        sent_annots = []
         for cs in cd:
-            sent_words = []
-            sent_tags = []
-            sent_heads = []
-            sent_deps = []
-            for id_, word, lemma, pos, tag, morph, head, dep, _1, _2 in cs:
+            sent = defaultdict(list)
+            for id_, word, lemma, pos, tag, morph, head, dep, _, space_after in cs:
                 if '.' in id_:
                     continue
                 if '-' in id_:
                     continue
                 id_ = int(id_)-1
                 head = int(head)-1 if head != '0' else id_
-                sent_words.append(word)
-                sent_tags.append(tag)
-                sent_heads.append(head)
-                sent_deps.append('ROOT' if dep == 'root' else dep)
+                sent['words'].append(word)
+                sent['tags'].append(tag)
+                sent['heads'].append(head)
+                sent['deps'].append('ROOT' if dep == 'root' else dep)
+                sent['spaces'].append(space_after == '_')
+            sent['entities'] = ['-'] * len(sent['words'])
+            sent['heads'], sent['deps'] = projectivize(sent['heads'],
+                                                       sent['deps'])
             if oracle_segments:
-                sent_heads, sent_deps = projectivize(sent_heads, sent_deps)
-                docs.append(Doc(nlp.vocab, words=sent_words))
-                golds.append(GoldParse(docs[-1], words=sent_words, heads=sent_heads,
-                                       tags=sent_tags, deps=sent_deps,
-                                       entities=['-']*len(sent_words)))
-            for head in sent_heads:
-                doc_heads.append(len(doc_words)+head)
-            doc_words.extend(sent_words)
-            doc_tags.extend(sent_tags)
-            doc_deps.extend(sent_deps)
-            doc_ents.extend(['-']*len(sent_words))
-        # Create a GoldParse object for the sentence
-        doc_heads, doc_deps = projectivize(doc_heads, doc_deps)
-        if raw_text:
-            docs.append(nlp.make_doc(text))
-            golds.append(GoldParse(docs[-1], words=doc_words, tags=doc_tags,
-                                   heads=doc_heads, deps=doc_deps,
-                                   entities=doc_ents))
-        if limit and doc_id >= limit:
-            break
+                docs.append(Doc(nlp.vocab, words=sent['words'], spaces=sent['spaces']))
+                golds.append(GoldParse(docs[-1], **sent))
+
+            sent_annots.append(sent)
+            if raw_text and max_doc_length and len(sent_annots) >= max_doc_length:
+                doc, gold = _make_gold(nlp, None, sent_annots)
+                sent_annots = []
+                docs.append(doc)
+                golds.append(gold)
+                if limit and len(docs) >= limit:
+                    return docs, golds
+
+        if raw_text and sent_annots:
+            doc, gold = _make_gold(nlp, None, sent_annots)
+            docs.append(doc)
+            golds.append(gold)
+        if limit and len(docs) >= limit:
+            return docs, golds
     return docs, golds
+
+
+def _make_gold(nlp, text, sent_annots):
+    # Flatten the conll annotations, and adjust the head indices
+    flat = defaultdict(list)
+    for sent in sent_annots:
+        flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
+        for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
+            flat[field].extend(sent[field])
+    # Construct text if necessary
+    assert len(flat['words']) == len(flat['spaces'])
+    if text is None:
+        text = ''.join(word+' '*space for word, space in zip(flat['words'], flat['spaces']))
+    doc = nlp.make_doc(text)
+    flat.pop('spaces')
+    gold = GoldParse(doc, **flat)
+    #for annot in gold.orig_annot:
+    #    print(annot)
+    #for i in range(len(doc)):
+    #    print(doc[i].text, gold.words[i], gold.labels[i], gold.heads[i])
+    return doc, gold
 
 
 def refresh_docs(docs):
     vocab = docs[0].vocab
     return [Doc(vocab, words=[t.text for t in doc],
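
_make_gold concatenates sentence-level annotations, so every head index has to be shifted by the number of words already flattened. A toy illustration of that arithmetic (the two-word sentences are invented):

    from collections import defaultdict

    sent_annots = [
        {'words': ['I', 'slept'], 'heads': [1, 1]},
        {'words': ['You', 'ran'], 'heads': [1, 1]},
    ]
    flat = defaultdict(list)
    for sent in sent_annots:
        # offset each head by the number of words flattened so far
        flat['heads'].extend(len(flat['words']) + head for head in sent['heads'])
        flat['words'].extend(sent['words'])
    assert flat['heads'] == [1, 1, 3, 3]

The same pass rebuilds the raw text from 'words' and 'spaces' when no text is given, with a space emitted only where the MISC column was '_' (i.e. no SpaceAfter=No). That is presumably the Chinese handling the commit title refers to, since Chinese treebanks mark SpaceAfter=No on nearly every token.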
@@ -124,8 +140,8 @@ def refresh_docs(docs):
 def read_conllu(file_):
     docs = []
-    doc = None
     sent = []
+    doc = []
     for line in file_:
         if line.startswith('# newdoc'):
             if doc:
@@ -135,29 +151,23 @@ def read_conllu(file_):
             continue
         elif not line.strip():
             if sent:
-                if doc is None:
-                    docs.append([sent])
-                else:
-                    doc.append(sent)
+                doc.append(sent)
             sent = []
         else:
             sent.append(line.strip().split())
     if sent:
-        if doc is None:
-            docs.append([sent])
-        else:
-            doc.append(sent)
+        doc.append(sent)
     if doc:
         docs.append(doc)
     return docs
 
 
 def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
-                   joint_sbd=True):
+                   joint_sbd=True, limit=None):
     with open(text_loc) as text_file:
         with open(conllu_loc) as conllu_file:
             docs, golds = read_data(nlp, conllu_file, text_file,
-                                    oracle_segments=oracle_segments)
+                                    oracle_segments=oracle_segments, limit=limit)
     if joint_sbd:
         pass
     else:
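
Two things change here: read_conllu no longer special-cases sentences seen before a `# newdoc` marker (doc now starts as a list, so every sentence funnels through doc.append(sent)), and parse_dev_data grows a limit that is passed straight through to read_data. Assuming the unchanged parts of read_conllu skip other comment lines, the return value is uniformly docs, then sentences, then token rows; a small sketch with an invented two-token document:

    import io

    sample = io.StringIO(
        '# newdoc id = demo\n'
        '1\t我\t我\tPRON\tPN\t_\t2\tnsubj\t_\tSpaceAfter=No\n'
        '2\t累\t累\tVERB\tVV\t_\t0\troot\t_\t_\n'
        '\n')
    docs = read_conllu(sample)
    assert len(docs) == 1         # one document
    assert len(docs[0]) == 1      # one sentence
    assert len(docs[0][0]) == 2   # two ten-column token rows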
@@ -200,10 +210,11 @@ def print_conllu(docs, file_):
     merger.add('SUBTOK', None, [{'DEP': 'subtok', 'op': '+'}])
     for i, doc in enumerate(docs):
         matches = merger(doc)
-        spans = [(doc[start].idx, doc[end+1].idx+len(doc[end+1]))
-                 for (_, start, end) in matches if end < (len(doc)-1)]
-        for start_char, end_char in spans:
+        spans = [doc[start:end+1] for _, start, end in matches]
+        offsets = [(span.start_char, span.end_char) for span in spans]
+        for start_char, end_char in offsets:
             doc.merge(start_char, end_char)
+        #print([t.text for t in doc])
         file_.write("# newdoc id = {i}\n".format(i=i))
         for j, sent in enumerate(doc.sents):
             file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
@@ -232,7 +243,7 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
         with open(text_train_loc) as text_file:
             docs, golds = read_data(nlp, conllu_file, text_file,
                                     oracle_segments=False, raw_text=True,
-                                    limit=None)
+                                    max_doc_length=10, limit=None)
     print("Create parser")
     nlp.add_pipe(nlp.create_pipe('parser'))
     nlp.parser.add_multitask_objective('tag')
@@ -257,7 +268,7 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
     # Batch size starts at 1 and grows, so that we make updates quickly
     # at the beginning of training.
     batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 1),
-                                         spacy.util.env_opt('batch_to', 2),
+                                         spacy.util.env_opt('batch_to', 8),
                                          spacy.util.env_opt('batch_compound', 1.001))
     for i in range(30):
         docs = refresh_docs(docs)
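
spacy.util.compounding yields an endless series that starts at batch_from and is multiplied by batch_compound at each step, clipped at batch_to, so raising the cap from 2 to 8 lets the batch size keep growing well into training. For instance:

    from itertools import islice
    import spacy.util

    sizes = spacy.util.compounding(1., 8., 1.001)
    print(list(islice(sizes, 3)))  # [1.0, 1.001, 1.002001]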
@@ -275,13 +286,15 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
         with nlp.use_params(optimizer.averages):
             dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
-                                              oracle_segments=False, joint_sbd=True)
+                                              oracle_segments=False, joint_sbd=True,
+                                              limit=5)
             print_progress(i, losses, scorer)
             with open(output_loc, 'w') as file_:
                 print_conllu(dev_docs, file_)
-            dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
-                                              oracle_segments=False, joint_sbd=False)
-            print_progress(i, losses, scorer)
+            with open('/tmp/train.conllu', 'w') as file_:
+                print_conllu(list(nlp.pipe([d.text for d in batch_docs])), file_)
 
 
 if __name__ == '__main__':
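
A closing note on the new debugging output: nlp.pipe streams parsed Doc objects for an iterable of raw texts, so the added block re-parses the most recent batch's text and writes the model's own segmentation and parses to /tmp/train.conllu for side-by-side comparison with the gold training data (batch_docs comes from the enclosing training loop, outside this hunk).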