Refactor conllu script

Matthew Honnibal 2018-02-25 12:48:22 +01:00
parent c388833ca6
commit 44e496a82e

@@ -4,8 +4,12 @@
 from __future__ import unicode_literals
 import plac
 import tqdm
+import attr
+from pathlib import Path
 import re
 import sys
+import json
 import spacy
 import spacy.util
 from spacy.tokens import Token, Doc
@@ -40,32 +44,9 @@ def minibatch_by_words(items, size=5000):
             batch.append((doc, gold))
         yield batch
 
-def get_token_acc(docs, golds):
-    '''Quick function to evaluate tokenization accuracy.'''
-    miss = 0
-    hit = 0
-    for doc, gold in zip(docs, golds):
-        for i in range(len(doc)):
-            token = doc[i]
-            align = gold.words[i]
-            if align == None:
-                miss += 1
-            else:
-                hit += 1
-    return miss, hit
-
-def golds_to_gold_tuples(docs, golds):
-    '''Get out the annoying 'tuples' format used by begin_training, given the
-    GoldParse objects.'''
-    tuples = []
-    for doc, gold in zip(docs, golds):
-        text = doc.text
-        ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
-        sents = [((ids, words, tags, heads, labels, iob), [])]
-        tuples.append((text, sents))
-    return tuples
+################
+# Data reading #
+################
 
 def split_text(text):
     return [par.strip().replace('\n', ' ')
@@ -127,34 +108,6 @@ def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
     return docs, golds
 
-def _make_gold(nlp, text, sent_annots):
-    # Flatten the conll annotations, and adjust the head indices
-    flat = defaultdict(list)
-    for sent in sent_annots:
-        flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
-        for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
-            flat[field].extend(sent[field])
-    # Construct text if necessary
-    assert len(flat['words']) == len(flat['spaces'])
-    if text is None:
-        text = ''.join(word+' '*space for word, space in zip(flat['words'], flat['spaces']))
-    doc = nlp.make_doc(text)
-    flat.pop('spaces')
-    gold = GoldParse(doc, **flat)
-    #for annot in gold.orig_annot:
-    #    print(annot)
-    #for i in range(len(doc)):
-    #    print(doc[i].text, gold.words[i], gold.labels[i], gold.heads[i])
-    return doc, gold
-
-def refresh_docs(docs):
-    vocab = docs[0].vocab
-    return [Doc(vocab, words=[t.text for t in doc],
-                spaces=[t.whitespace_ for t in doc])
-            for doc in docs]
 
 def read_conllu(file_):
     docs = []
     sent = []
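
The read_conllu function above consumes standard CoNLL-U input: comment lines starting with '#', one token per line with ten tab-separated columns (ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC), and a blank line terminating each sentence. A minimal, hand-made fragment for illustration; the sentence and its annotations are invented, not taken from the corpus or the commit:

    # Invented CoNLL-U fragment of the kind read_conllu consumes; wrap it in
    # io.StringIO to feed it to the function in a quick interactive test.
    SAMPLE_CONLLU = (
        "# sent_id = example-1\n"
        "# text = They buy books\n"
        "1\tThey\tthey\tPRON\tPRP\t_\t2\tnsubj\t_\t_\n"
        "2\tbuy\tbuy\tVERB\tVBP\t_\t0\troot\t_\t_\n"
        "3\tbooks\tbook\tNOUN\tNNS\t_\t2\tobj\t_\t_\n"
        "\n"
    )
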
@@ -179,6 +132,52 @@ def read_conllu(file_):
     return docs
 
+def _make_gold(nlp, text, sent_annots):
+    # Flatten the conll annotations, and adjust the head indices
+    flat = defaultdict(list)
+    for sent in sent_annots:
+        flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
+        for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
+            flat[field].extend(sent[field])
+    # Construct text if necessary
+    assert len(flat['words']) == len(flat['spaces'])
+    if text is None:
+        text = ''.join(word+' '*space for word, space in zip(flat['words'], flat['spaces']))
+    doc = nlp.make_doc(text)
+    flat.pop('spaces')
+    gold = GoldParse(doc, **flat)
+    #for annot in gold.orig_annot:
+    #    print(annot)
+    #for i in range(len(doc)):
+    #    print(doc[i].text, gold.words[i], gold.labels[i], gold.heads[i])
+    return doc, gold
+
+
+#############################
+# Data transforms for spaCy #
+#############################
+
+def golds_to_gold_tuples(docs, golds):
+    '''Get out the annoying 'tuples' format used by begin_training, given the
+    GoldParse objects.'''
+    tuples = []
+    for doc, gold in zip(docs, golds):
+        text = doc.text
+        ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
+        sents = [((ids, words, tags, heads, labels, iob), [])]
+        tuples.append((text, sents))
+    return tuples
+
+def refresh_docs(docs):
+    vocab = docs[0].vocab
+    return [Doc(vocab, words=[t.text for t in doc],
+                spaces=[t.whitespace_ for t in doc])
+            for doc in docs]
+
+
+##############
+# Evaluation #
+##############
 
 def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
                    joint_sbd=True, limit=None):
     with open(text_loc) as text_file:
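
As a reading aid for the block above: golds_to_gold_tuples packs each Doc into a (text, sentences) pair, where every sentence is a ((ids, words, tags, heads, labels, iob), brackets) tuple and the brackets list stays empty. A hypothetical instance with invented values; the field contents below are illustrative only, not taken from the commit:

    # Invented example of the nested "tuples" structure that
    # golds_to_gold_tuples returns; zip(*gold.orig_annot) produces one
    # tuple per annotation field, in this order.
    example_tuples = [
        (
            "They buy books",                  # doc.text
            [(
                (
                    (0, 1, 2),                 # ids
                    ("They", "buy", "books"),  # words
                    ("PRP", "VBP", "NNS"),     # tags
                    (1, 1, 1),                 # heads
                    ("nsubj", "ROOT", "dobj"), # dependency labels
                    ("O", "O", "O"),           # NER IOB tags
                ),
                [],                            # brackets (unused here)
            )],
        ),
    ]
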
@@ -265,25 +264,24 @@ Token.set_extension('begins_fused', default=False)
 Token.set_extension('inside_fused', default=False)
 
-def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
-         output_loc):
-    if lang == 'en':
+##################
+# Initialization #
+##################
+
+def load_nlp(corpus, config):
+    lang = corpus.split('_')[0]
     nlp = spacy.blank(lang)
-        vec_nlp = spacy.util.load_model('spacy/data/en_core_web_lg/en_core_web_lg-2.0.0')
-        nlp.vocab.vectors = vec_nlp.vocab.vectors
-        for lex in vec_nlp.vocab:
-            _ = nlp.vocab[lex.orth_]
-        vec_nlp = None
-    else:
-        nlp = spacy.load(lang)
-    with open(conllu_train_loc) as conllu_file:
-        with open(text_train_loc) as text_file:
-            docs, golds = read_data(nlp, conllu_file, text_file,
-                                    oracle_segments=False, raw_text=True,
-                                    max_doc_length=10, limit=None)
+    if config.vectors:
+        nlp.vocab.from_disk(config.vectors / 'vocab')
+    return nlp
+
+def initialize_pipeline(nlp, docs, golds, config):
     print("Create parser")
     nlp.add_pipe(nlp.create_pipe('parser'))
+    if config.multitask_tag:
         nlp.parser.add_multitask_objective('tag')
+    if config.multitask_sent:
         nlp.parser.add_multitask_objective('sent_start')
     nlp.parser.moves.add_action(2, 'subtok')
     nlp.add_pipe(nlp.create_pipe('tagger'))
@@ -291,7 +289,6 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
         for tag in gold.tags:
             if tag is not None:
                 nlp.tagger.add_label(tag)
-    optimizer = nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))
     # Replace labels that didn't make the frequency cutoff
     actions = set(nlp.parser.labels)
     label_set = set([act.split('-')[1] for act in actions if '-' in act])
@@ -299,38 +296,92 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
         for i, label in enumerate(gold.labels):
             if label is not None and label not in label_set:
                 gold.labels[i] = label.split('||')[0]
+    return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))
+
+
+########################
+# Command line helpers #
+########################
+
+@attr.s
+class Config(object):
+    vectors = attr.ib(default=None)
+    max_doc_length = attr.ib(default=10)
+    multitask_tag = attr.ib(default=True)
+    multitask_sent = attr.ib(default=True)
+    nr_epoch = attr.ib(default=30)
+    batch_size = attr.ib(default=1000)
+    dropout = attr.ib(default=0.2)
+
+    @classmethod
+    def load(cls, loc):
+        with Path(loc).open('r', encoding='utf8') as file_:
+            cfg = json.load(file_)
+        return cls(**cfg)
+
+class Dataset(object):
+    def __init__(self, path, section):
+        self.path = path
+        self.section = section
+        self.conllu = None
+        self.text = None
+        for file_path in self.path.iterdir():
+            name = file_path.parts[-1]
+            if section in name and name.endswith('conllu'):
+                self.conllu = file_path
+            elif section in name and name.endswith('txt'):
+                self.text = file_path
+        if self.conllu is None:
+            msg = "Could not find .conllu file in {path} for {section}"
+            raise IOError(msg.format(section=section, path=path))
+        if self.text is None:
+            msg = "Could not find .txt file in {path} for {section}"
+            raise IOError(msg.format(section=section, path=path))
+        self.lang = self.conllu.parts[-1].split('-')[0].split('_')[0]
+
+class TreebankPaths(object):
+    def __init__(self, ud_path, treebank, **cfg):
+        self.train = Dataset(ud_path / treebank, 'train')
+        self.dev = Dataset(ud_path / treebank, 'dev')
+        self.lang = self.train.lang
+
+
+@plac.annotations(
+    ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
+    config=("Path to json formatted config file", "positional", None, Config.load),
+    corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc",
+            "positional", None, str),
+    parses=("Path to write the development parses", "positional", None, Path)
+)
+def main(ud_dir, corpus, config, parses='/tmp/dev.conllu'):
+    paths = TreebankPaths(ud_dir, corpus)
+    nlp = load_nlp(paths.lang, config)
+    docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
+                            config)
+    optimizer = initialize_pipeline(nlp, docs, golds, config)
+
     n_train_words = sum(len(doc) for doc in docs)
-    print(n_train_words)
-    print("Begin training")
-    # Batch size starts at 1 and grows, so that we make updates quickly
-    # at the beginning of training.
-    batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 1),
-                                         spacy.util.env_opt('batch_to', 8),
-                                         spacy.util.env_opt('batch_compound', 1.001))
-    for i in range(30):
+    print("Begin training (%d words)" % n_train_words)
+    for i in range(config.nr_epoch):
         docs = refresh_docs(docs)
-        batches = minibatch_by_words(list(zip(docs, golds)), size=1000)
-        with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
+        batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
         losses = {}
-            for batch in batches:
+        for batch in tqdm.tqdm(batches, total=n_train_words//config.batch_size):
             if not batch:
                 continue
             batch_docs, batch_gold = zip(*batch)
             nlp.update(batch_docs, batch_gold, sgd=optimizer,
-                           drop=0.2, losses=losses)
-                pbar.update(sum(len(doc) for doc in batch_docs))
+                       drop=config.dropout, losses=losses)
         with nlp.use_params(optimizer.averages):
-            dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
-                                              oracle_segments=False, joint_sbd=True)
+            dev_docs, scorer = parse_dev_data(nlp, paths.dev.text, paths.dev.conllu,
+                                              **attr.asdict(config))
             print_progress(i, losses, scorer)
             with open(output_loc, 'w') as file_:
                 print_conllu(dev_docs, file_)
+            with open('/tmp/train.conllu', 'w') as file_:
+                print_conllu(list(nlp.pipe([d.text for d in batch_docs])), file_)
 
 
 if __name__ == '__main__':
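
To try the refactored script, Config.load expects a plain JSON object whose keys mirror the attr fields on Config, and main takes the UD corpus directory, the treebank/corpus name, the config path, and an optional output path for the development parses as positional arguments. A hedged sketch of how one might prepare and run it; the script name, corpus name, and paths below are placeholders, not taken from the commit:

    # Write a config.json that Config.load can read; every key mirrors an
    # attr.ib field on Config, and all values here are example defaults.
    import json
    from pathlib import Path

    example_config = {
        "vectors": None,          # left unset; load_nlp joins this with 'vocab' when given a path
        "max_doc_length": 10,
        "multitask_tag": True,
        "multitask_sent": True,
        "nr_epoch": 30,
        "batch_size": 1000,
        "dropout": 0.2,
    }
    Path("config.json").write_text(json.dumps(example_config, indent=2))

    # Assumed invocation (placeholders for the script, corpus directory, and treebank name):
    #   python conllu.py /data/ud-treebanks en config.json /tmp/dev.conllu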