diff --git a/examples/training/conllu.py b/examples/training/conllu.py
index 148475bbc..f60308980 100644
--- a/examples/training/conllu.py
+++ b/examples/training/conllu.py
@@ -4,8 +4,12 @@ from __future__ import unicode_literals
 import plac
 import tqdm
+import attr
+from pathlib import Path
 import re
 import sys
+import json
+
 import spacy
 import spacy.util
 from spacy.tokens import Token, Doc
@@ -40,32 +44,9 @@ def minibatch_by_words(items, size=5000):
             batch.append((doc, gold))
     yield batch
 
-
-def get_token_acc(docs, golds):
-    '''Quick function to evaluate tokenization accuracy.'''
-    miss = 0
-    hit = 0
-    for doc, gold in zip(docs, golds):
-        for i in range(len(doc)):
-            token = doc[i]
-            align = gold.words[i]
-            if align == None:
-                miss += 1
-            else:
-                hit += 1
-    return miss, hit
-
-
-def golds_to_gold_tuples(docs, golds):
-    '''Get out the annoying 'tuples' format used by begin_training, given the
-    GoldParse objects.'''
-    tuples = []
-    for doc, gold in zip(docs, golds):
-        text = doc.text
-        ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
-        sents = [((ids, words, tags, heads, labels, iob), [])]
-        tuples.append((text, sents))
-    return tuples
+################
+# Data reading #
+################
 
 def split_text(text):
     return [par.strip().replace('\n', ' ')
@@ -127,34 +108,6 @@ def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
     return docs, golds
 
 
-def _make_gold(nlp, text, sent_annots):
-    # Flatten the conll annotations, and adjust the head indices
-    flat = defaultdict(list)
-    for sent in sent_annots:
-        flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
-        for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
-            flat[field].extend(sent[field])
-    # Construct text if necessary
-    assert len(flat['words']) == len(flat['spaces'])
-    if text is None:
-        text = ''.join(word+' '*space for word, space in zip(flat['words'], flat['spaces']))
-    doc = nlp.make_doc(text)
-    flat.pop('spaces')
-    gold = GoldParse(doc, **flat)
-    #for annot in gold.orig_annot:
-    #    print(annot)
-    #for i in range(len(doc)):
-    #    print(doc[i].text, gold.words[i], gold.labels[i], gold.heads[i])
-    return doc, gold
-
-
-def refresh_docs(docs):
-    vocab = docs[0].vocab
-    return [Doc(vocab, words=[t.text for t in doc],
-                spaces=[t.whitespace_ for t in doc])
-            for doc in docs]
-
-
 def read_conllu(file_):
     docs = []
     sent = []
@@ -179,6 +132,52 @@ def read_conllu(file_):
     return docs
 
 
+def _make_gold(nlp, text, sent_annots):
+    # Flatten the conll annotations, and adjust the head indices
+    flat = defaultdict(list)
+    for sent in sent_annots:
+        flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
+        for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
+            flat[field].extend(sent[field])
+    # Construct text if necessary
+    assert len(flat['words']) == len(flat['spaces'])
+    if text is None:
+        text = ''.join(word+' '*space for word, space in zip(flat['words'], flat['spaces']))
+    doc = nlp.make_doc(text)
+    flat.pop('spaces')
+    gold = GoldParse(doc, **flat)
+    #for annot in gold.orig_annot:
+    #    print(annot)
+    #for i in range(len(doc)):
+    #    print(doc[i].text, gold.words[i], gold.labels[i], gold.heads[i])
+    return doc, gold
+
+#############################
+# Data transforms for spaCy #
+#############################
+
+def golds_to_gold_tuples(docs, golds):
+    '''Get out the annoying 'tuples' format used by begin_training, given the
+    GoldParse objects.'''
+    tuples = []
+    for doc, gold in zip(docs, golds):
+        text = doc.text
+        ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
+        sents = [((ids, words, tags, heads, labels, iob), [])]
+        tuples.append((text, sents))
+    return tuples
+
+
+def refresh_docs(docs):
+    vocab = docs[0].vocab
+    return [Doc(vocab, words=[t.text for t in doc],
+                spaces=[t.whitespace_ for t in doc])
+            for doc in docs]
+
+##############
+# Evaluation #
+##############
+
 def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
                    joint_sbd=True, limit=None):
     with open(text_loc) as text_file:
@@ -265,33 +264,31 @@ Token.set_extension('begins_fused', default=False)
 Token.set_extension('inside_fused', default=False)
 
 
-def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
-         output_loc):
-    if lang == 'en':
-        nlp = spacy.blank(lang)
-        vec_nlp = spacy.util.load_model('spacy/data/en_core_web_lg/en_core_web_lg-2.0.0')
-        nlp.vocab.vectors = vec_nlp.vocab.vectors
-        for lex in vec_nlp.vocab:
-            _ = nlp.vocab[lex.orth_]
-        vec_nlp = None
-    else:
-        nlp = spacy.load(lang)
-    with open(conllu_train_loc) as conllu_file:
-        with open(text_train_loc) as text_file:
-            docs, golds = read_data(nlp, conllu_file, text_file,
-                                    oracle_segments=False, raw_text=True,
-                                    max_doc_length=10, limit=None)
+##################
+# Initialization #
+##################
+
+
+def load_nlp(corpus, config):
+    lang = corpus.split('_')[0]
+    nlp = spacy.blank(lang)
+    if config.vectors:
+        nlp.vocab.from_disk(config.vectors / 'vocab')
+    return nlp
+
+def initialize_pipeline(nlp, docs, golds, config):
     print("Create parser")
     nlp.add_pipe(nlp.create_pipe('parser'))
-    nlp.parser.add_multitask_objective('tag')
-    nlp.parser.add_multitask_objective('sent_start')
+    if config.multitask_tag:
+        nlp.parser.add_multitask_objective('tag')
+    if config.multitask_sent:
+        nlp.parser.add_multitask_objective('sent_start')
     nlp.parser.moves.add_action(2, 'subtok')
     nlp.add_pipe(nlp.create_pipe('tagger'))
     for gold in golds:
         for tag in gold.tags:
             if tag is not None:
                 nlp.tagger.add_label(tag)
-    optimizer = nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))
     # Replace labels that didn't make the frequency cutoff
     actions = set(nlp.parser.labels)
     label_set = set([act.split('-')[1] for act in actions if '-' in act])
@@ -299,38 +296,92 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
         for i, label in enumerate(gold.labels):
             if label is not None and label not in label_set:
                 gold.labels[i] = label.split('||')[0]
-    n_train_words = sum(len(doc) for doc in docs)
-    print(n_train_words)
-    print("Begin training")
-    # Batch size starts at 1 and grows, so that we make updates quickly
-    # at the beginning of training.
-    batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 1),
-                                         spacy.util.env_opt('batch_to', 8),
-                                         spacy.util.env_opt('batch_compound', 1.001))
-    for i in range(30):
-        docs = refresh_docs(docs)
-        batches = minibatch_by_words(list(zip(docs, golds)), size=1000)
-        with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
-            losses = {}
-            for batch in batches:
-                if not batch:
-                    continue
-                batch_docs, batch_gold = zip(*batch)
+    return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))
 
-                nlp.update(batch_docs, batch_gold, sgd=optimizer,
-                           drop=0.2, losses=losses)
-                pbar.update(sum(len(doc) for doc in batch_docs))
+
+########################
+# Command line helpers #
+########################
+
+@attr.s
+class Config(object):
+    vectors = attr.ib(default=None)
+    max_doc_length = attr.ib(default=10)
+    multitask_tag = attr.ib(default=True)
+    multitask_sent = attr.ib(default=True)
+    nr_epoch = attr.ib(default=30)
+    batch_size = attr.ib(default=1000)
+    dropout = attr.ib(default=0.2)
+
+    @classmethod
+    def load(cls, loc):
+        with Path(loc).open('r', encoding='utf8') as file_:
+            cfg = json.load(file_)
+        return cls(**cfg)
+
+
+class Dataset(object):
+    def __init__(self, path, section):
+        self.path = path
+        self.section = section
+        self.conllu = None
+        self.text = None
+        for file_path in self.path.iterdir():
+            name = file_path.parts[-1]
+            if section in name and name.endswith('conllu'):
+                self.conllu = file_path
+            elif section in name and name.endswith('txt'):
+                self.text = file_path
+        if self.conllu is None:
+            msg = "Could not find .conllu file in {path} for {section}"
+            raise IOError(msg.format(section=section, path=path))
+        if self.text is None:
+            msg = "Could not find .txt file in {path} for {section}"
+            raise IOError(msg.format(section=section, path=path))
+        self.lang = self.conllu.parts[-1].split('-')[0].split('_')[0]
+
+
+class TreebankPaths(object):
+    def __init__(self, ud_path, treebank, **cfg):
+        self.train = Dataset(ud_path / treebank, 'train')
+        self.dev = Dataset(ud_path / treebank, 'dev')
+        self.lang = self.train.lang
+
+
+@plac.annotations(
+    ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
+    config=("Path to json formatted config file", "positional", None, Config.load),
+    corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc",
en, es_ancora, etc", + "positional", None, str), + parses=("Path to write the development parses", "positional", None, Path) +) +def main(ud_dir, corpus, config, parses='/tmp/dev.conllu'): + paths = TreebankPaths(ud_dir, corpus) + nlp = load_nlp(paths.lang, config) + + docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(), + config) + + optimizer = initialize_pipeline(nlp, docs, golds, config) + n_train_words = sum(len(doc) for doc in docs) + print("Begin training (%d words)" % n_train_words) + for i in range(config.nr_epoch): + docs = refresh_docs(docs) + batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size) + losses = {} + for batch in tqdm.tqdm(batches, total=n_train_words//config.batch_size): + if not batch: + continue + batch_docs, batch_gold = zip(*batch) + + nlp.update(batch_docs, batch_gold, sgd=optimizer, + drop=config.dropout, losses=losses) with nlp.use_params(optimizer.averages): - dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc, - oracle_segments=False, joint_sbd=True) + dev_docs, scorer = parse_dev_data(nlp, paths.dev.text, paths.dev.conllu, + **attr.asdict(config)) print_progress(i, losses, scorer) with open(output_loc, 'w') as file_: print_conllu(dev_docs, file_) - with open('/tmp/train.conllu', 'w') as file_: - print_conllu(list(nlp.pipe([d.text for d in batch_docs])), file_) - - if __name__ == '__main__':