diff --git a/spacy/cli/ud_train.py b/spacy/cli/ud_train.py new file mode 100644 index 000000000..853cff9b3 --- /dev/null +++ b/spacy/cli/ud_train.py @@ -0,0 +1,371 @@ +'''Train for CONLL 2017 UD treebank evaluation. Takes .conllu files, writes +.conllu format for development data, allowing the official scorer to be used. +''' +from __future__ import unicode_literals +import plac +import tqdm +from pathlib import Path +import re +import sys +import json + +import spacy +import spacy.util +from ..tokens import Token, Doc +from ..gold import GoldParse +from ..util import compounding, minibatch_by_words +from ..syntax.nonproj import projectivize +from ..matcher import Matcher +from .. import displacy +from collections import defaultdict, Counter +from timeit import default_timer as timer + +import itertools +import random +import numpy.random +import cytoolz + +from . import conll17_ud_eval + +from .. import lang +from .. import lang +from ..lang import zh +from ..lang import ja + + +################ +# Data reading # +################ + +space_re = re.compile('\s+') +def split_text(text): + return [space_re.sub(' ', par.strip()) for par in text.split('\n\n')] + + +def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False, + max_doc_length=None, limit=None): + '''Read the CONLLU format into (Doc, GoldParse) tuples. If raw_text=True, + include Doc objects created using nlp.make_doc and then aligned against + the gold-standard sequences. If oracle_segments=True, include Doc objects + created from the gold-standard segments. At least one must be True.''' + if not raw_text and not oracle_segments: + raise ValueError("At least one of raw_text or oracle_segments must be True") + paragraphs = split_text(text_file.read()) + conllu = read_conllu(conllu_file) + # sd is spacy doc; cd is conllu doc + # cs is conllu sent, ct is conllu token + docs = [] + golds = [] + for doc_id, (text, cd) in enumerate(zip(paragraphs, conllu)): + sent_annots = [] + for cs in cd: + sent = defaultdict(list) + for id_, word, lemma, pos, tag, morph, head, dep, _, space_after in cs: + if '.' in id_: + continue + if '-' in id_: + continue + id_ = int(id_)-1 + head = int(head)-1 if head != '0' else id_ + sent['words'].append(word) + sent['tags'].append(tag) + sent['heads'].append(head) + sent['deps'].append('ROOT' if dep == 'root' else dep) + sent['spaces'].append(space_after == '_') + sent['entities'] = ['-'] * len(sent['words']) + sent['heads'], sent['deps'] = projectivize(sent['heads'], + sent['deps']) + if oracle_segments: + docs.append(Doc(nlp.vocab, words=sent['words'], spaces=sent['spaces'])) + golds.append(GoldParse(docs[-1], **sent)) + + sent_annots.append(sent) + if raw_text and max_doc_length and len(sent_annots) >= max_doc_length: + doc, gold = _make_gold(nlp, None, sent_annots) + sent_annots = [] + docs.append(doc) + golds.append(gold) + if limit and len(docs) >= limit: + return docs, golds + + if raw_text and sent_annots: + doc, gold = _make_gold(nlp, None, sent_annots) + docs.append(doc) + golds.append(gold) + if limit and len(docs) >= limit: + return docs, golds + return docs, golds + + +def read_conllu(file_): + docs = [] + sent = [] + doc = [] + for line in file_: + if line.startswith('# newdoc'): + if doc: + docs.append(doc) + doc = [] + elif line.startswith('#'): + continue + elif not line.strip(): + if sent: + doc.append(sent) + sent = [] + else: + sent.append(list(line.strip().split('\t'))) + if len(sent[-1]) != 10: + print(repr(line)) + raise ValueError + if sent: + doc.append(sent) + if doc: + docs.append(doc) + return docs + + +def _make_gold(nlp, text, sent_annots): + # Flatten the conll annotations, and adjust the head indices + flat = defaultdict(list) + for sent in sent_annots: + flat['heads'].extend(len(flat['words'])+head for head in sent['heads']) + for field in ['words', 'tags', 'deps', 'entities', 'spaces']: + flat[field].extend(sent[field]) + # Construct text if necessary + assert len(flat['words']) == len(flat['spaces']) + if text is None: + text = ''.join(word+' '*space for word, space in zip(flat['words'], flat['spaces'])) + doc = nlp.make_doc(text) + flat.pop('spaces') + gold = GoldParse(doc, **flat) + return doc, gold + +############################# +# Data transforms for spaCy # +############################# + +def golds_to_gold_tuples(docs, golds): + '''Get out the annoying 'tuples' format used by begin_training, given the + GoldParse objects.''' + tuples = [] + for doc, gold in zip(docs, golds): + text = doc.text + ids, words, tags, heads, labels, iob = zip(*gold.orig_annot) + sents = [((ids, words, tags, heads, labels, iob), [])] + tuples.append((text, sents)) + return tuples + + +############## +# Evaluation # +############## + +def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None): + with text_loc.open('r', encoding='utf8') as text_file: + texts = split_text(text_file.read()) + docs = list(nlp.pipe(texts)) + with sys_loc.open('w', encoding='utf8') as out_file: + write_conllu(docs, out_file) + with gold_loc.open('r', encoding='utf8') as gold_file: + gold_ud = conll17_ud_eval.load_conllu(gold_file) + with sys_loc.open('r', encoding='utf8') as sys_file: + sys_ud = conll17_ud_eval.load_conllu(sys_file) + scores = conll17_ud_eval.evaluate(gold_ud, sys_ud) + return docs, scores + + +def write_conllu(docs, file_): + merger = Matcher(docs[0].vocab) + merger.add('SUBTOK', None, [{'DEP': 'subtok', 'op': '+'}]) + for i, doc in enumerate(docs): + matches = merger(doc) + spans = [doc[start:end+1] for _, start, end in matches] + offsets = [(span.start_char, span.end_char) for span in spans] + for start_char, end_char in offsets: + doc.merge(start_char, end_char) + file_.write("# newdoc id = {i}\n".format(i=i)) + for j, sent in enumerate(doc.sents): + file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j)) + file_.write("# text = {text}\n".format(text=sent.text)) + for k, token in enumerate(sent): + file_.write(token._.get_conllu_lines(k) + '\n') + file_.write('\n') + + +def print_progress(itn, losses, ud_scores): + fields = { + 'dep_loss': losses.get('parser', 0.0), + 'tag_loss': losses.get('tagger', 0.0), + 'words': ud_scores['Words'].f1 * 100, + 'sents': ud_scores['Sentences'].f1 * 100, + 'tags': ud_scores['XPOS'].f1 * 100, + 'uas': ud_scores['UAS'].f1 * 100, + 'las': ud_scores['LAS'].f1 * 100, + } + header = ['Epoch', 'Loss', 'LAS', 'UAS', 'TAG', 'SENT', 'WORD'] + if itn == 0: + print('\t'.join(header)) + tpl = '\t'.join(( + '{:d}', + '{dep_loss:.1f}', + '{las:.1f}', + '{uas:.1f}', + '{tags:.1f}', + '{sents:.1f}', + '{words:.1f}', + )) + print(tpl.format(itn, **fields)) + +#def get_sent_conllu(sent, sent_id): +# lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)] + +def get_token_conllu(token, i): + if token._.begins_fused: + n = 1 + while token.nbor(n)._.inside_fused: + n += 1 + id_ = '%d-%d' % (i, i+n) + lines = [id_, token.text, '_', '_', '_', '_', '_', '_', '_', '_'] + else: + lines = [] + if token.head.i == token.i: + head = 0 + else: + head = i + (token.head.i - token.i) + 1 + fields = [str(i+1), token.text, token.lemma_, token.pos_, token.tag_, '_', + str(head), token.dep_.lower(), '_', '_'] + lines.append('\t'.join(fields)) + return '\n'.join(lines) + +Token.set_extension('get_conllu_lines', method=get_token_conllu) +Token.set_extension('begins_fused', default=False) +Token.set_extension('inside_fused', default=False) + + +################## +# Initialization # +################## + + +def load_nlp(corpus, config): + lang = corpus.split('_')[0] + nlp = spacy.blank(lang) + if config.vectors: + nlp.vocab.from_disk(Path(config.vectors) / 'vocab') + return nlp + +def initialize_pipeline(nlp, docs, golds, config, device): + nlp.add_pipe(nlp.create_pipe('parser')) + if config.multitask_tag: + nlp.parser.add_multitask_objective('tag') + if config.multitask_sent: + nlp.parser.add_multitask_objective('sent_start') + nlp.add_pipe(nlp.create_pipe('tagger')) + for gold in golds: + for tag in gold.tags: + if tag is not None: + nlp.tagger.add_label(tag) + return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds), device=device) + + +######################## +# Command line helpers # +######################## + +class Config(object): + def __init__(self, vectors=None, max_doc_length=10, multitask_tag=True, + multitask_sent=True, nr_epoch=30, batch_size=1000, dropout=0.2): + for key, value in locals().items(): + setattr(self, key, value) + + @classmethod + def load(cls, loc): + with Path(loc).open('r', encoding='utf8') as file_: + cfg = json.load(file_) + return cls(**cfg) + + +class Dataset(object): + def __init__(self, path, section): + self.path = path + self.section = section + self.conllu = None + self.text = None + for file_path in self.path.iterdir(): + name = file_path.parts[-1] + if section in name and name.endswith('conllu'): + self.conllu = file_path + elif section in name and name.endswith('txt'): + self.text = file_path + if self.conllu is None: + msg = "Could not find .txt file in {path} for {section}" + raise IOError(msg.format(section=section, path=path)) + if self.text is None: + msg = "Could not find .txt file in {path} for {section}" + self.lang = self.conllu.parts[-1].split('-')[0].split('_')[0] + + +class TreebankPaths(object): + def __init__(self, ud_path, treebank, **cfg): + self.train = Dataset(ud_path / treebank, 'train') + self.dev = Dataset(ud_path / treebank, 'dev') + self.lang = self.train.lang + + +@plac.annotations( + ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path), + corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc", + "positional", None, str), + parses_dir=("Directory to write the development parses", "positional", None, Path), + config=("Path to json formatted config file", "positional"), + limit=("Size limit", "option", "n", int), + use_gpu=("Use GPU", "option", "g", int) +) +def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1): + spacy.util.fix_random_seed() + lang.zh.Chinese.Defaults.use_jieba = False + lang.ja.Japanese.Defaults.use_janome = False + + config = Config.load(config) + paths = TreebankPaths(ud_dir, corpus) + if not (parses_dir / corpus).exists(): + (parses_dir / corpus).mkdir() + print("Train and evaluate", corpus, "using lang", paths.lang) + nlp = load_nlp(paths.lang, config) + + docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(), + max_doc_length=config.max_doc_length, limit=limit) + + optimizer = initialize_pipeline(nlp, docs, golds, config, use_gpu) + + batch_sizes = compounding(config.batch_size//10, config.batch_size, 1.001) + for i in range(config.nr_epoch): + docs = [nlp.make_doc(doc.text) for doc in docs] + Xs = list(zip(docs, golds)) + random.shuffle(Xs) + batches = minibatch_by_words(Xs, size=batch_sizes) + losses = {} + n_train_words = sum(len(doc) for doc in docs) + with tqdm.tqdm(total=n_train_words, leave=False) as pbar: + for batch in batches: + batch_docs, batch_gold = zip(*batch) + pbar.update(sum(len(doc) for doc in batch_docs)) + nlp.update(batch_docs, batch_gold, sgd=optimizer, + drop=config.dropout, losses=losses) + + out_path = parses_dir / corpus / 'epoch-{i}.conllu'.format(i=i) + with nlp.use_params(optimizer.averages): + parsed_docs, scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path) + print_progress(i, losses, scores) + _render_parses(i, parsed_docs[:50]) + + +def _render_parses(i, to_render): + to_render[0].user_data['title'] = "Batch %d" % i + with Path('/tmp/parses.html').open('w') as file_: + html = displacy.render(to_render[:5], style='dep', page=True) + file_.write(html) + + +if __name__ == '__main__': + plac.call(main) diff --git a/spacy/language.py b/spacy/language.py index f04da7d30..65ec9ba5a 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -462,7 +462,7 @@ class Language(object): self._optimizer = sgd for name, proc in self.pipeline: if hasattr(proc, 'begin_training'): - proc.begin_training(get_gold_tuples(), + proc.begin_training(get_gold_tuples, pipeline=self.pipeline, sgd=self._optimizer, **cfg) diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index 743f6ac85..43a8896cc 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -172,7 +172,7 @@ class Pipe(object): return create_default_optimizer(self.model.ops, **self.cfg.get('optimizer', {})) - def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None, + def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs): """Initialize the pipe for training, using data exampes if available. If no model has been initialized yet, the model is added.""" @@ -374,7 +374,7 @@ class Tensorizer(Pipe): loss = (d_scores**2).sum() return loss, d_scores - def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None, + def begin_training(self, gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs): """Allocate models, pre-process training data and acquire an optimizer. @@ -498,11 +498,11 @@ class Tagger(Pipe): d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs]) return float(loss), d_scores - def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None, + def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs): orig_tag_map = dict(self.vocab.morphology.tag_map) new_tag_map = OrderedDict() - for raw_text, annots_brackets in gold_tuples: + for raw_text, annots_brackets in get_gold_tuples(): for annots, brackets in annots_brackets: ids, words, tags, heads, deps, ents = annots for tag in tags: @@ -673,9 +673,9 @@ class MultitaskObjective(Tagger): def set_annotations(self, docs, dep_ids, tensors=None): pass - def begin_training(self, gold_tuples=tuple(), pipeline=None, tok2vec=None, + def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, tok2vec=None, sgd=None, **kwargs): - gold_tuples = nonproj.preprocess_training_data(gold_tuples) + gold_tuples = nonproj.preprocess_training_data(get_gold_tuples()) for raw_text, annots_brackets in gold_tuples: for annots, brackets in annots_brackets: ids, words, tags, heads, deps, ents = annots @@ -898,7 +898,7 @@ class TextCategorizer(Pipe): self.labels.append(label) return 1 - def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None): + def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None): if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer': token_vector_width = pipeline[0].model.nO else: @@ -925,10 +925,10 @@ cdef class DependencyParser(Parser): labeller = MultitaskObjective(self.vocab, target=target) self._multitasks.append(labeller) - def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg): + def init_multitask_objectives(self, get_gold_tuples, pipeline, sgd=None, **cfg): for labeller in self._multitasks: tok2vec = self.model[0] - labeller.begin_training(gold_tuples, pipeline=pipeline, + labeller.begin_training(get_gold_tuples, pipeline=pipeline, tok2vec=tok2vec, sgd=sgd) def __reduce__(self): @@ -946,10 +946,10 @@ cdef class EntityRecognizer(Parser): labeller = MultitaskObjective(self.vocab, target=target) self._multitasks.append(labeller) - def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg): + def init_multitask_objectives(self, get_gold_tuples, pipeline, sgd=None, **cfg): for labeller in self._multitasks: tok2vec = self.model[0] - labeller.begin_training(gold_tuples, pipeline=pipeline, + labeller.begin_training(get_gold_tuples, pipeline=pipeline, tok2vec=tok2vec) def __reduce__(self): diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index f8cd964ef..4e2124182 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -164,16 +164,17 @@ cdef void sum_state_features(float* output, cdef const float* feature padding = cached cached += F * O + cdef int id_stride = F*O + cdef float one = 1. for b in range(B): for f in range(F): if token_ids[f] < 0: feature = &padding[f*O] else: - idx = token_ids[f] * F * O + f*O + idx = token_ids[f] * id_stride + f*O feature = &cached[idx] - for i in range(O): - output[i] += feature[i] - output += O + openblas.simple_axpy(&output[b*O], O, + feature, one) token_ids += F @@ -726,7 +727,7 @@ cdef class Parser: lower, stream, drop=0.0) return (tokvecs, bp_tokvecs), state2vec, upper - nr_feature = 13 + nr_feature = 8 def get_token_ids(self, states): cdef StateClass state @@ -821,15 +822,13 @@ cdef class Parser: copy_array(larger.b[:smaller.nO], smaller.b) self.model[-1]._layers[-1] = larger - def begin_training(self, gold_tuples, pipeline=None, sgd=None, **cfg): + def begin_training(self, get_gold_tuples, pipeline=None, sgd=None, **cfg): if 'model' in cfg: self.model = cfg['model'] - gold_tuples = nonproj.preprocess_training_data(gold_tuples, - label_freq_cutoff=100) - actions = self.moves.get_actions(gold_parses=gold_tuples) - for action, labels in actions.items(): - for label in labels: - self.moves.add_action(action, label) + cfg.setdefault('min_action_freq', 30) + actions = self.moves.get_actions(gold_parses=get_gold_tuples(), + min_freq=cfg.get('min_action_freq', 30)) + self.moves.initialize_actions(actions) cfg.setdefault('token_vector_width', 128) if self.model is True: cfg['pretrained_dims'] = self.vocab.vectors_length @@ -837,9 +836,9 @@ cdef class Parser: if sgd is None: sgd = self.create_optimizer() self.model[1].begin_training( - self.model[1].ops.allocate((5, cfg['token_vector_width']))) + self.model[1].ops.allocate((5, cfg['token_vector_width']))) if pipeline is not None: - self.init_multitask_objectives(gold_tuples, pipeline, sgd=sgd, **cfg) + self.init_multitask_objectives(get_gold_tuples, pipeline, sgd=sgd, **cfg) link_vectors_to_models(self.vocab) else: if sgd is None: @@ -853,7 +852,7 @@ cdef class Parser: # Defined in subclasses, to avoid circular import raise NotImplementedError - def init_multitask_objectives(self, gold_tuples, pipeline, **cfg): + def init_multitask_objectives(self, get_gold_tuples, pipeline, **cfg): '''Setup models for secondary objectives, to benefit from multi-task learning. This method is intended to be overridden by subclasses.