Refactor conllu script

This commit is contained in:
Matthew Honnibal 2018-02-25 12:48:22 +01:00
parent c388833ca6
commit 44e496a82e

View File

@ -4,8 +4,12 @@
from __future__ import unicode_literals
import plac
import tqdm
import attr
from pathlib import Path
import re
import sys
import json
import spacy
import spacy.util
from spacy.tokens import Token, Doc
@ -40,32 +44,9 @@ def minibatch_by_words(items, size=5000):
batch.append((doc, gold))
yield batch
def get_token_acc(docs, golds):
    '''Quick function to evaluate tokenization accuracy.

    For every token position of each doc, the entry at the same index in
    the paired gold's `words` is inspected: a None entry counts as a miss,
    anything else counts as a hit.

    docs: Iterable of token sequences (each must support len()).
    golds: Iterable of objects with an indexable `words` attribute, paired
        with `docs` by position.
    RETURNS: A (miss, hit) tuple of counts.
    '''
    miss = 0
    hit = 0
    for doc, gold in zip(docs, golds):
        for i in range(len(doc)):
            # Identity comparison: None is the marker for "no alignment".
            if gold.words[i] is None:
                miss += 1
            else:
                hit += 1
    return miss, hit
def golds_to_gold_tuples(docs, golds):
    '''Convert paired docs and GoldParse objects into the 'tuples' format
    used by begin_training.

    Each result entry is (text, sents), where sents holds a single
    ((ids, words, tags, heads, labels, iob), []) item built by transposing
    the rows of gold.orig_annot.
    '''
    def to_entry(doc, gold):
        text = doc.text
        # Transpose the per-token rows into six per-column tuples.
        ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
        columns = (ids, words, tags, heads, labels, iob)
        return (text, [(columns, [])])

    return [to_entry(doc, gold) for doc, gold in zip(docs, golds)]
################
# Data reading #
################
def split_text(text):
return [par.strip().replace('\n', ' ')
@ -127,34 +108,6 @@ def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
return docs, golds
def _make_gold(nlp, text, sent_annots):
    '''Build a (doc, gold) pair from per-sentence annotation dicts.

    The sentence-level lists are concatenated into document-level lists,
    with each sentence's head indices shifted by the number of words that
    precede it. If `text` is None, it is reconstructed from the words and
    their `spaces` flags.
    '''
    merged = defaultdict(list)
    for annots in sent_annots:
        heads = merged['heads']
        # Heads are sentence-relative; shift them to document positions.
        offset = len(merged['words'])
        heads.extend(offset + head for head in annots['heads'])
        for key in ('words', 'tags', 'deps', 'entities', 'spaces'):
            merged[key].extend(annots[key])
    assert len(merged['words']) == len(merged['spaces'])
    if text is None:
        pieces = [word + ' ' * space
                  for word, space in zip(merged['words'], merged['spaces'])]
        text = ''.join(pieces)
    doc = nlp.make_doc(text)
    # 'spaces' is only needed for text reconstruction, not by GoldParse.
    del merged['spaces']
    gold = GoldParse(doc, **merged)
    return doc, gold
def refresh_docs(docs):
    '''Rebuild every doc as a new Doc from its token texts and whitespace.

    All new Doc objects are created with the vocab of the first input doc.

    docs: Sequence of docs to copy.
    RETURNS: A list of new Doc objects; an empty list if `docs` is empty.
    '''
    # Without this guard, docs[0] raises IndexError on empty input.
    if not docs:
        return []
    vocab = docs[0].vocab
    return [Doc(vocab, words=[t.text for t in doc],
                spaces=[t.whitespace_ for t in doc])
            for doc in docs]
def read_conllu(file_):
docs = []
sent = []
@ -179,6 +132,52 @@ def read_conllu(file_):
return docs
def _make_gold(nlp, text, sent_annots):
    '''Build a (doc, gold) pair from per-sentence annotation dicts.

    The sentence-level lists are concatenated into document-level lists,
    with each sentence's head indices shifted by the number of words that
    precede it. If `text` is None, it is reconstructed from the words and
    their `spaces` flags.
    '''
    merged = defaultdict(list)
    for annots in sent_annots:
        heads = merged['heads']
        # Heads are sentence-relative; shift them to document positions.
        offset = len(merged['words'])
        heads.extend(offset + head for head in annots['heads'])
        for key in ('words', 'tags', 'deps', 'entities', 'spaces'):
            merged[key].extend(annots[key])
    assert len(merged['words']) == len(merged['spaces'])
    if text is None:
        pieces = [word + ' ' * space
                  for word, space in zip(merged['words'], merged['spaces'])]
        text = ''.join(pieces)
    doc = nlp.make_doc(text)
    # 'spaces' is only needed for text reconstruction, not by GoldParse.
    del merged['spaces']
    gold = GoldParse(doc, **merged)
    return doc, gold
#############################
# Data transforms for spaCy #
#############################
def golds_to_gold_tuples(docs, golds):
    '''Convert paired docs and GoldParse objects into the 'tuples' format
    used by begin_training.

    Each result entry is (text, sents), where sents holds a single
    ((ids, words, tags, heads, labels, iob), []) item built by transposing
    the rows of gold.orig_annot.
    '''
    def to_entry(doc, gold):
        text = doc.text
        # Transpose the per-token rows into six per-column tuples.
        ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
        columns = (ids, words, tags, heads, labels, iob)
        return (text, [(columns, [])])

    return [to_entry(doc, gold) for doc, gold in zip(docs, golds)]
def refresh_docs(docs):
    '''Rebuild every doc as a new Doc from its token texts and whitespace.

    All new Doc objects are created with the vocab of the first input doc.

    docs: Sequence of docs to copy.
    RETURNS: A list of new Doc objects; an empty list if `docs` is empty.
    '''
    # Without this guard, docs[0] raises IndexError on empty input.
    if not docs:
        return []
    vocab = docs[0].vocab
    return [Doc(vocab, words=[t.text for t in doc],
                spaces=[t.whitespace_ for t in doc])
            for doc in docs]
##############
# Evaluation #
##############
def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
joint_sbd=True, limit=None):
with open(text_loc) as text_file:
@ -265,25 +264,24 @@ Token.set_extension('begins_fused', default=False)
Token.set_extension('inside_fused', default=False)
def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
output_loc):
if lang == 'en':
##################
# Initialization #
##################
def load_nlp(corpus, config):
lang = corpus.split('_')[0]
nlp = spacy.blank(lang)
vec_nlp = spacy.util.load_model('spacy/data/en_core_web_lg/en_core_web_lg-2.0.0')
nlp.vocab.vectors = vec_nlp.vocab.vectors
for lex in vec_nlp.vocab:
_ = nlp.vocab[lex.orth_]
vec_nlp = None
else:
nlp = spacy.load(lang)
with open(conllu_train_loc) as conllu_file:
with open(text_train_loc) as text_file:
docs, golds = read_data(nlp, conllu_file, text_file,
oracle_segments=False, raw_text=True,
max_doc_length=10, limit=None)
if config.vectors:
nlp.vocab.from_disk(config.vectors / 'vocab')
return nlp
def initialize_pipeline(nlp, docs, golds, config):
print("Create parser")
nlp.add_pipe(nlp.create_pipe('parser'))
if config.multitask_tag:
nlp.parser.add_multitask_objective('tag')
if config.multitask_sent:
nlp.parser.add_multitask_objective('sent_start')
nlp.parser.moves.add_action(2, 'subtok')
nlp.add_pipe(nlp.create_pipe('tagger'))
@ -291,7 +289,6 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
for tag in gold.tags:
if tag is not None:
nlp.tagger.add_label(tag)
optimizer = nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))
# Replace labels that didn't make the frequency cutoff
actions = set(nlp.parser.labels)
label_set = set([act.split('-')[1] for act in actions if '-' in act])
@ -299,38 +296,92 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
for i, label in enumerate(gold.labels):
if label is not None and label not in label_set:
gold.labels[i] = label.split('||')[0]
return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))
########################
# Command line helpers #
########################
@attr.s
class Config(object):
    '''Training settings, loadable from a JSON file via Config.load.'''
    # Optional directory of pre-trained vectors; load_nlp reads the
    # 'vocab' sub-path beneath it, so it must support the `/` operator.
    vectors = attr.ib(default=None)
    # NOTE(review): presumably the document-length cap used when reading
    # the training data -- confirm against read_data.
    max_doc_length = attr.ib(default=10)
    # Whether to add the 'tag' multitask objective to the parser.
    multitask_tag = attr.ib(default=True)
    # Whether to add the 'sent_start' multitask objective to the parser.
    multitask_sent = attr.ib(default=True)
    # Number of passes over the training data.
    nr_epoch = attr.ib(default=30)
    # Batch size passed to minibatch_by_words.
    batch_size = attr.ib(default=1000)
    # Dropout rate passed to nlp.update.
    dropout = attr.ib(default=0.2)

    @classmethod
    def load(cls, loc):
        '''Create a Config from the JSON object stored at `loc`.

        loc: Path (or path string) of a UTF-8 encoded JSON file whose keys
            are attribute names of this class.
        RETURNS: The populated Config.
        '''
        with Path(loc).open('r', encoding='utf8') as file_:
            cfg = json.load(file_)
        # JSON can only give us a string, but the vectors location is
        # joined with `/` later on, which requires a Path.
        if cfg.get('vectors') is not None:
            cfg['vectors'] = Path(cfg['vectors'])
        return cls(**cfg)
class Dataset(object):
    '''Locate the .conllu and .txt files for one section of a treebank.

    The directory `path` is scanned for files whose name contains
    `section`; the one ending in 'conllu' is stored as `self.conllu` and
    the one ending in 'txt' as `self.text`. `self.lang` is taken from the
    conllu file name: the part before the first '-', cut at the first '_'.

    path: Directory to search (a pathlib.Path).
    section: Substring identifying the section, e.g. 'train' or 'dev'.
    RAISES: IOError if either the .conllu or the .txt file is missing.
    '''
    def __init__(self, path, section):
        self.path = path
        self.section = section
        self.conllu = None
        self.text = None
        for file_path in self.path.iterdir():
            name = file_path.parts[-1]
            if section not in name:
                continue
            if name.endswith('conllu'):
                self.conllu = file_path
            elif name.endswith('txt'):
                self.text = file_path
        if self.conllu is None:
            msg = "Could not find .conllu file in {path} for {section}"
            raise IOError(msg.format(section=section, path=path))
        if self.text is None:
            msg = "Could not find .txt file in {path} for {section}"
            raise IOError(msg.format(section=section, path=path))
        self.lang = self.conllu.parts[-1].split('-')[0].split('_')[0]
class TreebankPaths(object):
    '''Bundle the 'train' and 'dev' datasets of a single treebank.

    The treebank directory is `ud_path / treebank`. Extra keyword
    arguments are accepted and ignored. `lang` is copied from the
    training dataset.
    '''
    def __init__(self, ud_path, treebank, **cfg):
        train_set = Dataset(ud_path / treebank, 'train')
        dev_set = Dataset(ud_path / treebank, 'dev')
        self.train = train_set
        self.dev = dev_set
        self.lang = train_set.lang
@plac.annotations(
ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
config=("Path to json formatted config file", "positional", None, Config.load),
corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc",
"positional", None, str),
parses=("Path to write the development parses", "positional", None, Path)
)
def main(ud_dir, corpus, config, parses='/tmp/dev.conllu'):
paths = TreebankPaths(ud_dir, corpus)
nlp = load_nlp(paths.lang, config)
docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
config)
optimizer = initialize_pipeline(nlp, docs, golds, config)
n_train_words = sum(len(doc) for doc in docs)
print(n_train_words)
print("Begin training")
# Batch size starts at 1 and grows, so that we make updates quickly
# at the beginning of training.
batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 1),
spacy.util.env_opt('batch_to', 8),
spacy.util.env_opt('batch_compound', 1.001))
for i in range(30):
print("Begin training (%d words)" % n_train_words)
for i in range(config.nr_epoch):
docs = refresh_docs(docs)
batches = minibatch_by_words(list(zip(docs, golds)), size=1000)
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
losses = {}
for batch in batches:
for batch in tqdm.tqdm(batches, total=n_train_words//config.batch_size):
if not batch:
continue
batch_docs, batch_gold = zip(*batch)
nlp.update(batch_docs, batch_gold, sgd=optimizer,
drop=0.2, losses=losses)
pbar.update(sum(len(doc) for doc in batch_docs))
drop=config.dropout, losses=losses)
with nlp.use_params(optimizer.averages):
dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
oracle_segments=False, joint_sbd=True)
dev_docs, scorer = parse_dev_data(nlp, paths.dev.text, paths.dev.conllu,
**attr.asdict(config))
print_progress(i, losses, scorer)
with open(output_loc, 'w') as file_:
print_conllu(dev_docs, file_)
with open('/tmp/train.conllu', 'w') as file_:
print_conllu(list(nlp.pipe([d.text for d in batch_docs])), file_)
if __name__ == '__main__':