Mirror of https://github.com/explosion/spaCy.git
Generalize conllu script. Now handling Chinese (maybe badly)
parent 5cc3bd1c1d · commit 8adeea3746
@@ -11,7 +11,7 @@ import spacy.util
 from spacy.tokens import Doc
 from spacy.gold import GoldParse, minibatch
 from spacy.syntax.nonproj import projectivize
-from collections import Counter
+from collections import defaultdict, Counter
 from timeit import default_timer as timer
 from spacy.matcher import Matcher
 
@@ -56,7 +56,7 @@ def split_text(text):
 
 
 def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
-              limit=None):
+              max_doc_length=None, limit=None):
     '''Read the CONLLU format into (Doc, GoldParse) tuples. If raw_text=True,
     include Doc objects created using nlp.make_doc and then aligned against
     the gold-standard sequences. If oracle_segments=True, include Doc objects
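The new `max_doc_length` caps how many sentences get concatenated into a single training Doc. For orientation, a hypothetical call mirroring the one `main()` makes further down; the paths and language code are placeholders, not from this commit:

```python
# Sketch only: file paths and 'zh' are illustrative placeholders.
import spacy

nlp = spacy.blank('zh')
with open('zh-ud-train.conllu') as conllu_file, open('zh-ud-train.txt') as text_file:
    docs, golds = read_data(nlp, conllu_file, text_file,
                            oracle_segments=False, raw_text=True,
                            max_doc_length=10, limit=1000)
```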
@@ -70,51 +70,67 @@ def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
     docs = []
     golds = []
     for doc_id, (text, cd) in enumerate(zip(paragraphs, conllu)):
-        doc_words = []
-        doc_tags = []
-        doc_heads = []
-        doc_deps = []
-        doc_ents = []
+        sent_annots = []
         for cs in cd:
-            sent_words = []
-            sent_tags = []
-            sent_heads = []
-            sent_deps = []
-            for id_, word, lemma, pos, tag, morph, head, dep, _1, _2 in cs:
+            sent = defaultdict(list)
+            for id_, word, lemma, pos, tag, morph, head, dep, _, space_after in cs:
                 if '.' in id_:
                     continue
                 if '-' in id_:
                     continue
                 id_ = int(id_)-1
                 head = int(head)-1 if head != '0' else id_
-                sent_words.append(word)
-                sent_tags.append(tag)
-                sent_heads.append(head)
-                sent_deps.append('ROOT' if dep == 'root' else dep)
+                sent['words'].append(word)
+                sent['tags'].append(tag)
+                sent['heads'].append(head)
+                sent['deps'].append('ROOT' if dep == 'root' else dep)
+                sent['spaces'].append(space_after == '_')
+            sent['entities'] = ['-'] * len(sent['words'])
+            sent['heads'], sent['deps'] = projectivize(sent['heads'],
+                                                       sent['deps'])
             if oracle_segments:
-                sent_heads, sent_deps = projectivize(sent_heads, sent_deps)
-                docs.append(Doc(nlp.vocab, words=sent_words))
-                golds.append(GoldParse(docs[-1], words=sent_words, heads=sent_heads,
-                                       tags=sent_tags, deps=sent_deps,
-                                       entities=['-']*len(sent_words)))
-            for head in sent_heads:
-                doc_heads.append(len(doc_words)+head)
-            doc_words.extend(sent_words)
-            doc_tags.extend(sent_tags)
-            doc_deps.extend(sent_deps)
-            doc_ents.extend(['-']*len(sent_words))
-        # Create a GoldParse object for the sentence
-        doc_heads, doc_deps = projectivize(doc_heads, doc_deps)
-        if raw_text:
-            docs.append(nlp.make_doc(text))
-            golds.append(GoldParse(docs[-1], words=doc_words, tags=doc_tags,
-                                   heads=doc_heads, deps=doc_deps,
-                                   entities=doc_ents))
-        if limit and doc_id >= limit:
-            break
+                docs.append(Doc(nlp.vocab, words=sent['words'], spaces=sent['spaces']))
+                golds.append(GoldParse(docs[-1], **sent))
+
+            sent_annots.append(sent)
+            if raw_text and max_doc_length and len(sent_annots) >= max_doc_length:
+                doc, gold = _make_gold(nlp, None, sent_annots)
+                sent_annots = []
+                docs.append(doc)
+                golds.append(gold)
+                if limit and len(docs) >= limit:
+                    return docs, golds
+
+        if raw_text and sent_annots:
+            doc, gold = _make_gold(nlp, None, sent_annots)
+            docs.append(doc)
+            golds.append(gold)
+        if limit and len(docs) >= limit:
+            return docs, golds
     return docs, golds
 
 
+def _make_gold(nlp, text, sent_annots):
+    # Flatten the conll annotations, and adjust the head indices
+    flat = defaultdict(list)
+    for sent in sent_annots:
+        flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
+        for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
+            flat[field].extend(sent[field])
+    # Construct text if necessary
+    assert len(flat['words']) == len(flat['spaces'])
+    if text is None:
+        text = ''.join(word+' '*space for word, space in zip(flat['words'], flat['spaces']))
+    doc = nlp.make_doc(text)
+    flat.pop('spaces')
+    gold = GoldParse(doc, **flat)
+    #for annot in gold.orig_annot:
+    #    print(annot)
+    #for i in range(len(doc)):
+    #    print(doc[i].text, gold.words[i], gold.labels[i], gold.heads[i])
+    return doc, gold
+
+
 def refresh_docs(docs):
     vocab = docs[0].vocab
     return [Doc(vocab, words=[t.text for t in doc],
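The two steps that make this generalization work are both in `_make_gold`: sentence-local head indices are re-based onto document-level positions, and the raw text is rebuilt from the SpaceAfter flags, which is what lets the script handle Chinese, where tokens are joined without spaces. A minimal standalone sketch with toy data (not from the commit) showing both steps:

```python
from collections import defaultdict

# Step 1: flatten per-sentence annotations, re-basing sentence-local
# head indices onto document-level token positions.
sent_annots = [
    {'words': ['I', 'ate'], 'heads': [1, 1]},      # 'ate' is the root
    {'words': ['She', 'slept'], 'heads': [1, 1]},  # 'slept' is the root
]
flat = defaultdict(list)
for sent in sent_annots:
    # Offset each head by the number of words flattened so far. The order
    # matters: heads must be extended before words, exactly as in _make_gold.
    flat['heads'].extend(len(flat['words']) + head for head in sent['heads'])
    flat['words'].extend(sent['words'])
print(flat['heads'])  # [1, 1, 3, 3]

# Step 2: rebuild the raw text from the per-token space flags. For Chinese
# the flags are mostly False, so the words concatenate without spaces.
words = ['我', '爱', '猫']
spaces = [False, False, False]
print(''.join(w + ' ' * s for w, s in zip(words, spaces)))  # 我爱猫
```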
@@ -124,8 +140,8 @@ def refresh_docs(docs):
 
 def read_conllu(file_):
     docs = []
-    doc = None
     sent = []
+    doc = []
     for line in file_:
         if line.startswith('# newdoc'):
             if doc:
@@ -135,29 +151,23 @@ def read_conllu(file_):
             continue
         elif not line.strip():
             if sent:
-                if doc is None:
-                    docs.append([sent])
-                else:
-                    doc.append(sent)
+                doc.append(sent)
             sent = []
         else:
             sent.append(line.strip().split())
     if sent:
-        if doc is None:
-            docs.append([sent])
-        else:
-            doc.append(sent)
+        doc.append(sent)
     if doc:
         docs.append(doc)
     return docs
 
 
 def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
-                   joint_sbd=True):
+                   joint_sbd=True, limit=None):
     with open(text_loc) as text_file:
         with open(conllu_loc) as conllu_file:
             docs, golds = read_data(nlp, conllu_file, text_file,
-                                    oracle_segments=oracle_segments)
+                                    oracle_segments=oracle_segments, limit=limit)
     if joint_sbd:
         pass
     else:
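The simplified `read_conllu` now always returns a list of documents, each a list of sentences, each a list of 10-column token rows. A quick sanity check on a toy two-sentence document; the toy input is illustrative, and it assumes the lines elided between the hunks reset `doc` after appending, as the surrounding context suggests:

```python
import io

toy = io.StringIO(
    "# newdoc id = 0\n"
    "1\tHello\t_\t_\t_\t_\t0\troot\t_\t_\n"
    "\n"
    "1\tBye\t_\t_\t_\t_\t0\troot\t_\t_\n"
    "\n"
)
docs = read_conllu(toy)
assert len(docs) == 1             # one document
assert len(docs[0]) == 2          # two sentences
assert docs[0][0][0][1] == 'Hello'  # first row of first sentence, FORM column
```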
@@ -200,10 +210,11 @@ def print_conllu(docs, file_):
     merger.add('SUBTOK', None, [{'DEP': 'subtok', 'op': '+'}])
     for i, doc in enumerate(docs):
         matches = merger(doc)
-        spans = [(doc[start].idx, doc[end+1].idx+len(doc[end+1]))
-                 for (_, start, end) in matches if end < (len(doc)-1)]
-        for start_char, end_char in spans:
+        spans = [doc[start:end+1] for _, start, end in matches]
+        offsets = [(span.start_char, span.end_char) for span in spans]
+        for start_char, end_char in offsets:
             doc.merge(start_char, end_char)
+        #print([t.text for t in doc])
         file_.write("# newdoc id = {i}\n".format(i=i))
         for j, sent in enumerate(doc.sents):
             file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
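The rewritten merge logic computes all character offsets up front via `Span.start_char`/`Span.end_char`: character offsets stay valid while `doc.merge()` shrinks the token sequence, and the span API also handles a match ending on the final token, which the old index arithmetic had to skip. A minimal sketch of the pattern against the spaCy v2 API this script targets (`doc.merge()` was later superseded by `doc.retokenize()`):

```python
import spacy
from spacy.tokens import Doc

nlp = spacy.blank('en')
doc = Doc(nlp.vocab, words=['New', 'York', 'is', 'big'])

span = doc[0:2]                                    # token-based span
start_char, end_char = span.start_char, span.end_char
doc.merge(start_char, end_char)                    # merge by character offsets
print([t.text for t in doc])                       # ['New York', 'is', 'big']
```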
@@ -232,7 +243,7 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
         with open(text_train_loc) as text_file:
             docs, golds = read_data(nlp, conllu_file, text_file,
                                     oracle_segments=False, raw_text=True,
-                                    limit=None)
+                                    max_doc_length=10, limit=None)
     print("Create parser")
     nlp.add_pipe(nlp.create_pipe('parser'))
     nlp.parser.add_multitask_objective('tag')
@@ -257,7 +268,7 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
     # Batch size starts at 1 and grows, so that we make updates quickly
     # at the beginning of training.
     batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 1),
-                                         spacy.util.env_opt('batch_to', 2),
+                                         spacy.util.env_opt('batch_to', 8),
                                          spacy.util.env_opt('batch_compound', 1.001))
     for i in range(30):
         docs = refresh_docs(docs)
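`spacy.util.compounding` yields an infinite series in which each value is the previous one multiplied by the compound factor, clipped at the stop value, so batch sizes drift from `batch_from` toward the (now larger) `batch_to` as training proceeds. A rough equivalent for intuition; this is a sketch, not spaCy's exact implementation:

```python
def compounding(start, stop, compound):
    """Yield start, start*compound, start*compound**2, ..., capped at stop."""
    value = float(start)
    while True:
        yield min(value, stop)
        value *= compound

sizes = compounding(1, 8, 1.001)
print([round(next(sizes), 3) for _ in range(3)])  # [1.0, 1.001, 1.002]
```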
@@ -275,13 +286,15 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
         with nlp.use_params(optimizer.averages):
             dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
-                                              oracle_segments=False, joint_sbd=True)
+                                              oracle_segments=False, joint_sbd=True,
+                                              limit=5)
             print_progress(i, losses, scorer)
             with open(output_loc, 'w') as file_:
                 print_conllu(dev_docs, file_)
-            dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
-                                              oracle_segments=False, joint_sbd=False)
-            print_progress(i, losses, scorer)
+            with open('/tmp/train.conllu', 'w') as file_:
+                print_conllu(list(nlp.pipe([d.text for d in batch_docs])), file_)
+
+
 
 
 if __name__ == '__main__':