From 8e48b58cd6322f880febd2d192eb5197e7bd87d6 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Sat, 6 May 2017 16:47:15 +0200
Subject: [PATCH] Gradients look correct

---
 bin/parser/train_ud.py      |  86 ++++++++++++++++++++---------
 spacy/_ml.py                |  34 +++++++++++-
 spacy/syntax/parser.pyx     | 106 ++++++++++++++++++++----------------
 spacy/syntax/stateclass.pyx |  35 +++++++-----
 4 files changed, 173 insertions(+), 88 deletions(-)

diff --git a/bin/parser/train_ud.py b/bin/parser/train_ud.py
index afc4491cb..e9dc2a443 100644
--- a/bin/parser/train_ud.py
+++ b/bin/parser/train_ud.py
@@ -1,4 +1,4 @@
-from __future__ import unicode_literals
+from __future__ import unicode_literals, print_function
 import plac
 import json
 import random
@@ -9,7 +9,7 @@ from spacy.syntax.nonproj import PseudoProjectivity
 from spacy.language import Language
 from spacy.gold import GoldParse
 from spacy.tagger import Tagger
-from spacy.pipeline import DependencyParser, BeamDependencyParser
+from spacy.pipeline import DependencyParser, TokenVectorEncoder
 from spacy.syntax.parser import get_templates
 from spacy.syntax.arc_eager import ArcEager
 from spacy.scorer import Scorer
@@ -36,10 +36,10 @@ def read_conllx(loc, n=0):
             try:
                 id_ = int(id_) - 1
                 head = (int(head) - 1) if head != '0' else id_
-                dep = 'ROOT' if dep == 'root' else dep
-                tokens.append((id_, word, tag, head, dep, 'O'))
+                dep = 'ROOT' if dep == 'root' else 'unlabelled'
+                # Hack for efficiency: pack POS+morphology into the tag field
+                tokens.append((id_, word, pos+'__'+morph, head, dep, 'O'))
             except:
-                print(line)
                 raise
         tuples = [list(t) for t in zip(*tokens)]
         yield (None, [[tuples, []]])
@@ -48,19 +48,37 @@ def read_conllx(loc, n=0):
         if n and i >= n:
             break
 
 
-def score_model(vocab, tagger, parser, gold_docs, verbose=False):
+def score_model(vocab, encoder, tagger, parser, Xs, ys, verbose=False):
     scorer = Scorer()
-    for _, gold_doc in gold_docs:
-        for (ids, words, tags, heads, deps, entities), _ in gold_doc:
-            doc = Doc(vocab, words=words)
-            tagger(doc)
-            parser(doc)
-            PseudoProjectivity.deprojectivize(doc)
-            gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
-            scorer.score(doc, gold, verbose=verbose)
+    correct = 0.
+    total = 0.
+    for doc, gold in zip(Xs, ys):
+        doc = Doc(vocab, words=[w.text for w in doc])
+        encoder(doc)
+        tagger(doc)
+        parser(doc)
+        PseudoProjectivity.deprojectivize(doc)
+        scorer.score(doc, gold, verbose=verbose)
+        for token, tag in zip(doc, gold.tags):
+            univ_guess, _ = token.tag_.split('_', 1)
+            univ_truth, _ = tag.split('_', 1)
+            correct += univ_guess == univ_truth
+            total += 1
     return scorer
 
 
+def organize_data(vocab, train_sents):
+    Xs = []
+    ys = []
+    for _, doc_sents in train_sents:
+        for (ids, words, tags, heads, deps, ner), _ in doc_sents:
+            doc = Doc(vocab, words=words)
+            gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
+            Xs.append(doc)
+            ys.append(gold)
+    return Xs, ys
+
+
 def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
     LangClass = spacy.util.get_lang_class(lang_name)
     train_sents = list(read_conllx(train_loc))
@@ -114,21 +132,37 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
     for tag in tags:
         assert tag in vocab.morphology.tag_map, repr(tag)
     tagger = Tagger(vocab)
+    encoder = TokenVectorEncoder(vocab)
     parser = DependencyParser(vocab, actions=actions, features=features, L1=0.0)
-    for itn in range(30):
-        loss = 0.
-        for _, doc_sents in train_sents:
-            for (ids, words, tags, heads, deps, ner), _ in doc_sents:
-                doc = Doc(vocab, words=words)
-                gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
-                tagger(doc)
-                loss += parser.update(doc, gold, itn=itn)
-                doc = Doc(vocab, words=words)
+
+    Xs, ys = organize_data(vocab, train_sents)
+    # Debug setting: restrict training to a single example, to check that
+    # the model can overfit it and that the gradients behave.
+    Xs = Xs[:1]
+    ys = ys[:1]
+    with encoder.model.begin_training(Xs[:100], ys[:100]) as (trainer, optimizer):
+        docs = list(Xs)
+        for doc in docs:
+            encoder(doc)
+        parser.begin_training(docs, ys)
+        nn_loss = [0.]
+        def track_progress():
+            scorer = score_model(vocab, encoder, tagger, parser, Xs, ys)
+            itn = len(nn_loss)
+            print('%d:\t%.3f\t%.3f\t%.3f' % (itn, nn_loss[-1], scorer.uas, scorer.tags_acc))
+            nn_loss.append(0.)
+        trainer.each_epoch.append(track_progress)
+        trainer.batch_size = 1
+        trainer.nb_epoch = 100
+        for docs, golds in trainer.iterate(Xs, ys, progress_bar=False):
+            docs = [Doc(vocab, words=[w.text for w in doc]) for doc in docs]
+            tokvecs, upd_tokvecs = encoder.begin_update(docs)
+            for doc, tokvec in zip(docs, tokvecs):
+                doc.tensor = tokvec
+            for doc, gold in zip(docs, golds):
                 tagger.update(doc, gold)
-        random.shuffle(train_sents)
-        scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
-        print('%d:\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.tags_acc))
+            d_tokvecs, loss = parser.update(docs, golds, sgd=optimizer)
+            upd_tokvecs(d_tokvecs, sgd=optimizer)
+            nn_loss[-1] += loss
     nlp = LangClass(vocab=vocab, tagger=tagger, parser=parser)
     nlp.end_training(model_dir)
     scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
diff --git a/spacy/_ml.py b/spacy/_ml.py
index bb228da3f..460c2a2c8 100644
--- a/spacy/_ml.py
+++ b/spacy/_ml.py
@@ -1,5 +1,5 @@
 from thinc.api import layerize, chain, clone, concatenate, with_flatten
-from thinc.neural import Model, Maxout, Softmax
+from thinc.neural import Model, Maxout, Softmax, Affine
 from thinc.neural._classes.hash_embed import HashEmbed
 from thinc.neural._classes.convolution import ExtractWindow
 
@@ -21,11 +21,41 @@ def build_model(state2vec, width, depth, nr_class):
             state2vec
             >> Maxout(width, 1344)
             >> Maxout(width, width)
-            >> Softmax(nr_class, width)
+            >> Affine(nr_class, width)
         )
     return model
 
 
+def build_debug_model(state2vec, width, depth, nr_class):
+    # Outputs raw scores (no Softmax layer): the parser masks invalid
+    # actions and normalizes the scores itself.
+    with Model.define_operators({'>>': chain, '**': clone}):
+        model = (
+            state2vec
+            >> Maxout(width)
+            >> Affine(nr_class)
+        )
+    return model
+
+
+def build_debug_state2vec(width, nr_vector=1000, nF=1, nB=0, nS=1, nL=2, nR=2):
+    ops = Model.ops
+    def forward(tokens_attrs_vectors, drop=0.):
+        tokens, attr_vals, tokvecs = tokens_attrs_vectors
+        orig_tokvecs_shape = tokvecs.shape
+        # Concatenate the context tokens' vectors into one state vector.
+        tokvecs = tokvecs.reshape((tokvecs.shape[0],
+                                   tokvecs.shape[1] * tokvecs.shape[2]))
+        vector = tokvecs
+        def backward(d_vector, sgd=None):
+            # Reshape the incoming gradient, not the forward activation.
+            d_tokvecs = d_vector.reshape(orig_tokvecs_shape)
+            return (tokens, d_tokvecs)
+        return vector, backward
+    model = layerize(forward)
+    return model
+
+
 def build_parser_state2vec(width, nr_vector=1000, nF=1, nB=0, nS=1, nL=2, nR=2):
     embed_tags = _reshape(chain(get_col(0), HashEmbed(16, nr_vector)))
     embed_deps = _reshape(chain(get_col(1), HashEmbed(16, nr_vector)))
diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx
index 67643617a..bd1c1650b 100644
--- a/spacy/syntax/parser.pyx
+++ b/spacy/syntax/parser.pyx
@@ -28,6 +28,8 @@ from murmurhash.mrmr cimport hash64
 from preshed.maps cimport MapStruct
 from preshed.maps cimport map_get
 
+from numpy import exp
+
 from . import _parse_features
 from ._parse_features cimport CONTEXT_SIZE
 from ._parse_features cimport fill_context
@@ -43,6 +45,7 @@ from ..gold cimport GoldParse
 from ..attrs cimport TAG, DEP
 
 from .._ml import build_parser_state2vec, build_model
+from .._ml import build_debug_state2vec, build_debug_model
 
 
 USE_FTRL = True
@@ -111,8 +114,8 @@ cdef class Parser:
         return (Parser, (self.vocab, self.moves, self.model), None, None)
 
     def build_model(self, width=8, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_):
-        state2vec = build_parser_state2vec(width, nr_vector, nF, nB, nL, nR)
-        model = build_model(state2vec, width, 2, self.moves.n_moves)
+        state2vec = build_debug_state2vec(width, nr_vector, nF, nB, nL, nR)
+        model = build_debug_model(state2vec, width, 2, self.moves.n_moves)
         return model
 
     def __call__(self, Doc tokens):
@@ -166,32 +169,22 @@ cdef class Parser:
         cdef Doc doc
         cdef StateClass state
         cdef int guess
-        is_valid = self.model.ops.allocate((len(docs), nr_class), dtype='i')
         tokvecs = [d.tensor for d in docs]
-        attr_names = self.model.ops.allocate((2,), dtype='i')
-        attr_names[0] = TAG
-        attr_names[1] = DEP
         all_states = list(states)
         todo = zip(states, tokvecs)
         while todo:
             states, tokvecs = zip(*todo)
-            features = self._get_features(states, tokvecs, attr_names)
-            scores = self.model.predict(features)
-            self._validate_batch(is_valid, states)
-            scores *= is_valid
+            scores, _ = self._begin_update(states, tokvecs)
             for state, guess in zip(states, scores.argmax(axis=1)):
                 action = self.moves.c[guess]
                 action.do(state.c, action.label)
-            todo = filter(lambda sp: not sp[0].is_final(), todo)
+            # list() so the emptiness test also works on Python 3, where
+            # filter() returns a lazy iterator.
+            todo = list(filter(lambda sp: not sp[0].py_is_final(), todo))
         for state, doc in zip(all_states, docs):
             self.moves.finalize_state(state.c)
             for i in range(doc.length):
                 doc.c[i] = state.c._sent[i]
-
-    def update(self, docs, golds, drop=0., sgd=None):
-        if isinstance(docs, Doc) and isinstance(golds, GoldParse):
-            return self.update([docs], [golds], drop=drop)
+    def begin_training(self, docs, golds):
         for gold in golds:
             self.moves.preprocess_gold(gold)
         states = self._init_states(docs)
@@ -204,39 +197,60 @@ cdef class Parser:
         attr_names = self.model.ops.allocate((2,), dtype='i')
         attr_names[0] = TAG
         attr_names[1] = DEP
+
+        features = self._get_features(states, tokvecs, attr_names)
+        self.model.begin_training(features)
+
+    def update(self, docs, golds, drop=0., sgd=None):
+        if isinstance(docs, Doc) and isinstance(golds, GoldParse):
+            return self.update([docs], [golds], drop=drop, sgd=sgd)
+        for gold in golds:
+            self.moves.preprocess_gold(gold)
+        states = self._init_states(docs)
+        tokvecs = [d.tensor for d in docs]
+        d_tokens = [self.model.ops.allocate(d.tensor.shape) for d in docs]
+        nr_class = self.moves.n_moves
         output = list(d_tokens)
         todo = zip(states, tokvecs, golds, d_tokens)
         assert len(states) == len(todo)
         loss = 0.
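+        # Each pass over the batch: featurize and score the unfinished
+        # states, backprop the action gradients into the token vectors,
+        # then advance each state by its predicted transition.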
         while todo:
             states, tokvecs, golds, d_tokens = zip(*todo)
-            features = self._get_features(states, tokvecs, attr_names)
-
-            scores, finish_update = self.model.begin_update(features, drop=drop)
-            assert scores.shape == (len(states), self.moves.n_moves), (len(states), scores.shape)
-
-            self._cost_batch(costs, is_valid, states, golds)
-            scores *= is_valid
-            self._set_gradient(gradients, scores, costs)
-            loss += numpy.abs(gradients).sum() / gradients.shape[0]
-
-            token_ids, batch_token_grads = finish_update(gradients, sgd=sgd)
+            scores, finish_update = self._begin_update(states, tokvecs)
+            token_ids, batch_token_grads = finish_update(golds, sgd=sgd)
             for i, tok_i in enumerate(token_ids):
                 d_tokens[i][tok_i] += batch_token_grads[i]
             self._transition_batch(states, scores)
             # Get unfinished states (and their matching gold and token gradients)
-            todo = filter(lambda sp: not sp[0].is_final(), todo)
-            costs = costs[:len(todo)]
-            is_valid = is_valid[:len(todo)]
-            gradients = gradients[:len(todo)]
-
-            gradients.fill(0)
-            costs.fill(0)
-            is_valid.fill(1)
+            todo = list(filter(lambda sp: not sp[0].py_is_final(), todo))
         return output, loss
 
+    def _begin_update(self, states, tokvecs, drop=0.):
+        nr_class = self.moves.n_moves
+        attr_names = self.model.ops.allocate((2,), dtype='i')
+        attr_names[0] = TAG
+        attr_names[1] = DEP
+
+        features = self._get_features(states, tokvecs, attr_names)
+        scores, finish_update = self.model.begin_update(features, drop=drop)
+        is_valid = self.model.ops.allocate((len(states), nr_class), dtype='i')
+        self._validate_batch(is_valid, states)
+        # Softmax over the valid actions only; keepdims=True so the row
+        # sums broadcast correctly when the batch has more than one state.
+        softmaxed = self.model.ops.softmax(scores)
+        softmaxed *= is_valid
+        softmaxed /= softmaxed.sum(axis=1, keepdims=True)
+        print('Scores', softmaxed[0])
+        def backward(golds, sgd=None):
+            costs = self.model.ops.allocate((len(states), nr_class), dtype='f')
+            d_scores = self.model.ops.allocate((len(states), nr_class), dtype='f')
+
+            self._cost_batch(costs, is_valid, states, golds)
+            self._set_gradient(d_scores, scores, is_valid, costs)
+            return finish_update(d_scores, sgd=sgd)
+        return softmaxed, backward
+
     def _init_states(self, docs):
         states = []
         cdef Doc doc
@@ -281,20 +295,20 @@ cdef class Parser:
             action = self.moves.c[guess]
             action.do(state.c, action.label)
 
-    def _set_gradient(self, gradients, scores, costs):
+    def _set_gradient(self, gradients, scores, is_valid, costs):
         """Do multi-label log loss"""
         cdef double Z, gZ, max_, g_max
-        g_scores = scores * (costs <= 0)
-        maxes = scores.max(axis=1).reshape((scores.shape[0], 1))
-        g_maxes = g_scores.max(axis=1).reshape((g_scores.shape[0], 1))
-        exps = numpy.exp((scores-maxes))
-        g_exps = numpy.exp(g_scores-g_maxes)
-
-        Zs = exps.sum(axis=1).reshape((exps.shape[0], 1))
-        gZs = g_exps.sum(axis=1).reshape((g_exps.shape[0], 1))
-        logprob = exps / Zs
-        g_logprob = g_exps / gZs
-        gradients[:] = logprob - g_logprob
+        scores = scores * is_valid
+        g_scores = scores * is_valid * (costs <= 0.)
+        # keepdims=True: the row-wise max and sum must broadcast over the
+        # (n_states, nr_class) arrays.
+        exps = numpy.exp(scores - scores.max(axis=1, keepdims=True))
+        exps *= is_valid
+        g_exps = numpy.exp(g_scores - g_scores.max(axis=1, keepdims=True))
+        g_exps *= costs <= 0.
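+        # Mask invalid actions out of the gold distribution too, so that
+        # forbidden moves contribute to neither normalization term.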
+        g_exps *= is_valid
+        gradients[:] = exps / exps.sum(axis=1, keepdims=True)
+        gradients -= g_exps / g_exps.sum(axis=1, keepdims=True)
+        print('Gradient', gradients[0])
+        print('Costs', costs[0])
 
     def step_through(self, Doc doc, GoldParse gold=None):
         """
diff --git a/spacy/syntax/stateclass.pyx b/spacy/syntax/stateclass.pyx
index 5dedd1501..0e0b63d09 100644
--- a/spacy/syntax/stateclass.pyx
+++ b/spacy/syntax/stateclass.pyx
@@ -34,7 +34,7 @@ cdef class StateClass:
     def token_vector_lenth(self):
         return self.doc.tensor.shape[1]
 
-    def is_final(self):
+    def py_is_final(self):
        return self.c.is_final()
 
     def print_state(self, words):
@@ -47,31 +47,38 @@ cdef class StateClass:
         return ' '.join((third, second, top, '|', n0, n1))
 
     def nr_context_tokens(self, int nF, int nB, int nS, int nL, int nR):
-        return 1+nF+nB+nS + nL + (nS * nL) + (nS * nR)
+        # Debug: only three context tokens for now (B0, S0, S1).
+        return 3
+        #return 1+nF+nB+nS + nL + (nS * nL) + (nS * nR)
 
     def set_context_tokens(self, int[:] output, nF=1, nB=0, nS=2, nL=2, nR=2):
         output[0] = self.B(0)
         output[1] = self.S(0)
         output[2] = self.S(1)
-        output[3] = self.L(self.S(0), 1)
-        output[4] = self.L(self.S(0), 2)
-        output[5] = self.R(self.S(0), 1)
-        output[6] = self.R(self.S(0), 2)
-        output[7] = self.L(self.S(1), 1)
-        output[8] = self.L(self.S(1), 2)
-        output[9] = self.R(self.S(1), 1)
-        output[10] = self.R(self.S(1), 2)
+        #output[3] = self.L(self.S(0), 1)
+        #output[4] = self.L(self.S(0), 2)
+        #output[5] = self.R(self.S(0), 1)
+        #output[6] = self.R(self.S(0), 2)
+        #output[7] = self.L(self.S(1), 1)
+        #output[8] = self.L(self.S(1), 2)
+        #output[9] = self.R(self.S(1), 1)
+        #output[10] = self.R(self.S(1), 2)
 
     def set_attributes(self, uint64_t[:, :] vals, int[:] tokens, int[:] names):
         cdef int i, j, tok_i
         for i in range(tokens.shape[0]):
             tok_i = tokens[i]
-            token = &self.c._sent[tok_i]
-            for j in range(names.shape[0]):
-                vals[i, j] = Token.get_struct_attr(token, names[j])
+            if tok_i >= 0:
+                token = &self.c._sent[tok_i]
+                for j in range(names.shape[0]):
+                    vals[i, j] = Token.get_struct_attr(token, names[j])
+            else:
+                # Padding position (e.g. empty stack slot): zero the values.
+                vals[i] = 0
 
     def set_token_vectors(self, float[:, :] tokvecs, float[:, :] all_tokvecs, int[:] indices):
         for i in range(indices.shape[0]):
-            tokvecs[i] = all_tokvecs[indices[i]]
+            if indices[i] >= 0:
+                tokvecs[i] = all_tokvecs[indices[i]]
+            else:
+                # Padding position: use a zero vector.
+                tokvecs[i] = 0
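-- 
A standalone sketch of the gradient that _set_gradient computes above: the
multi-label log loss pushes the model's softmax over the valid actions
toward the same distribution renormalized over the zero-cost
(gold-compatible) actions. The helper name and the toy inputs below are
illustrative only, not spaCy API:

    import numpy

    def multi_label_log_loss_gradient(scores, is_valid, costs):
        # scores:   (n_states, n_actions) raw scores from the model
        # is_valid: (n_states, n_actions) 1 where the transition system
        #           allows the action, else 0
        # costs:    (n_states, n_actions) structured costs; cost <= 0 marks
        #           the gold-compatible actions
        scores = scores * is_valid
        g_scores = scores * (costs <= 0.)
        # P(action): numerically stabilized softmax over valid actions only
        exps = numpy.exp(scores - scores.max(axis=1, keepdims=True))
        exps *= is_valid
        # P(action | gold-compatible): renormalize over zero-cost actions
        g_exps = numpy.exp(g_scores - g_scores.max(axis=1, keepdims=True))
        g_exps *= is_valid * (costs <= 0.)
        # dL/dscores is the difference between the two distributions
        return (exps / exps.sum(axis=1, keepdims=True)
                - g_exps / g_exps.sum(axis=1, keepdims=True))

    scores = numpy.asarray([[2., 1., -1.]])
    is_valid = numpy.asarray([[1., 1., 0.]])
    costs = numpy.asarray([[1., 0., 0.]])
    print(multi_label_log_loss_gradient(scores, is_valid, costs))
    # Approximately [[ 0.731 -0.731  0.   ]]: each row sums to zero, and
    # the gold-compatible action receives the negative gradient.

Each row summing to zero is a quick sanity check consistent with the commit
message: if every valid action were gold-compatible, the two distributions
would coincide and the gradient would vanish.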