Gradients look correct

Matthew Honnibal 2017-05-06 16:47:15 +02:00
parent 7e04260d38
commit 8e48b58cd6
4 changed files with 173 additions and 88 deletions

View File

@@ -1,4 +1,4 @@
-from __future__ import unicode_literals
+from __future__ import unicode_literals, print_function
 import plac
 import json
 import random
@@ -9,7 +9,7 @@ from spacy.syntax.nonproj import PseudoProjectivity
 from spacy.language import Language
 from spacy.gold import GoldParse
 from spacy.tagger import Tagger
-from spacy.pipeline import DependencyParser, BeamDependencyParser
+from spacy.pipeline import DependencyParser, TokenVectorEncoder
 from spacy.syntax.parser import get_templates
 from spacy.syntax.arc_eager import ArcEager
 from spacy.scorer import Scorer
@@ -36,10 +36,10 @@ def read_conllx(loc, n=0):
                 try:
                     id_ = int(id_) - 1
                     head = (int(head) - 1) if head != '0' else id_
-                    dep = 'ROOT' if dep == 'root' else dep
-                    tokens.append((id_, word, tag, head, dep, 'O'))
+                    dep = 'ROOT' if dep == 'root' else 'unlabelled'
+                    # Hack for efficiency
+                    tokens.append((id_, word, pos+'__'+morph, head, dep, 'O'))
                 except:
+                    print(line)
                     raise
             tuples = [list(t) for t in zip(*tokens)]
             yield (None, [[tuples, []]])
@@ -48,19 +48,37 @@ def read_conllx(loc, n=0):
                 break


-def score_model(vocab, tagger, parser, gold_docs, verbose=False):
+def score_model(vocab, encoder, tagger, parser, Xs, ys, verbose=False):
     scorer = Scorer()
-    for _, gold_doc in gold_docs:
-        for (ids, words, tags, heads, deps, entities), _ in gold_doc:
-            doc = Doc(vocab, words=words)
-            tagger(doc)
-            parser(doc)
-            PseudoProjectivity.deprojectivize(doc)
-            gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
-            scorer.score(doc, gold, verbose=verbose)
+    correct = 0.
+    total = 0.
+    for doc, gold in zip(Xs, ys):
+        doc = Doc(vocab, words=[w.text for w in doc])
+        encoder(doc)
+        tagger(doc)
+        parser(doc)
+        PseudoProjectivity.deprojectivize(doc)
+        scorer.score(doc, gold, verbose=verbose)
+        for token, tag in zip(doc, gold.tags):
+            univ_guess, _ = token.tag_.split('_', 1)
+            univ_truth, _ = tag.split('_', 1)
+            correct += univ_guess == univ_truth
+            total += 1
     return scorer


+def organize_data(vocab, train_sents):
+    Xs = []
+    ys = []
+    for _, doc_sents in train_sents:
+        for (ids, words, tags, heads, deps, ner), _ in doc_sents:
+            doc = Doc(vocab, words=words)
+            gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
+            Xs.append(doc)
+            ys.append(gold)
+    return Xs, ys
+
+
 def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
     LangClass = spacy.util.get_lang_class(lang_name)
     train_sents = list(read_conllx(train_loc))
@@ -114,21 +132,37 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
     for tag in tags:
         assert tag in vocab.morphology.tag_map, repr(tag)
     tagger = Tagger(vocab)
+    encoder = TokenVectorEncoder(vocab)
     parser = DependencyParser(vocab, actions=actions, features=features, L1=0.0)

-    for itn in range(30):
-        loss = 0.
-        for _, doc_sents in train_sents:
-            for (ids, words, tags, heads, deps, ner), _ in doc_sents:
-                doc = Doc(vocab, words=words)
-                gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
-                tagger(doc)
-                loss += parser.update(doc, gold, itn=itn)
-                doc = Doc(vocab, words=words)
-                tagger.update(doc, gold)
-        random.shuffle(train_sents)
-        scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
-        print('%d:\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.tags_acc))
+    Xs, ys = organize_data(vocab, train_sents)
+    Xs = Xs[:1]
+    ys = ys[:1]
+    with encoder.model.begin_training(Xs[:100], ys[:100]) as (trainer, optimizer):
+        docs = list(Xs)
+        for doc in docs:
+            encoder(doc)
+        parser.begin_training(docs, ys)
+        nn_loss = [0.]
+        def track_progress():
+            scorer = score_model(vocab, encoder, tagger, parser, Xs, ys)
+            itn = len(nn_loss)
+            print('%d:\t%.3f\t%.3f\t%.3f' % (itn, nn_loss[-1], scorer.uas, scorer.tags_acc))
+            nn_loss.append(0.)
+        trainer.each_epoch.append(track_progress)
+        trainer.batch_size = 1
+        trainer.nb_epoch = 100
+        for docs, golds in trainer.iterate(Xs, ys, progress_bar=False):
+            docs = [Doc(vocab, words=[w.text for w in doc]) for doc in docs]
+            tokvecs, upd_tokvecs = encoder.begin_update(docs)
+            for doc, tokvec in zip(docs, tokvecs):
+                doc.tensor = tokvec
+            for doc, gold in zip(docs, golds):
+                tagger.update(doc, gold)
+            d_tokvecs, loss = parser.update(docs, golds, sgd=optimizer)
+            upd_tokvecs(d_tokvecs, sgd=optimizer)
+            nn_loss[-1] += loss
     nlp = LangClass(vocab=vocab, tagger=tagger, parser=parser)
     nlp.end_training(model_dir)
     scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
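
Note: the new training loop relies on the callback-style backprop used throughout thinc: `begin_update` returns a layer's output together with a callback (here `upd_tokvecs`) that takes the gradient of that output and returns the gradient of the input, applying weight updates through the optimizer. A minimal sketch of the pattern with a toy scale-by-3 layer (illustrative only, not part of this commit):

    import numpy

    def begin_update(X, drop=0.):
        # Forward: a toy layer that multiplies its input by 3.
        Y = X * 3.0
        def finish_update(dY, sgd=None):
            # Backward: d(3X)/dX = 3, so the input gradient is 3 * dY.
            # A real layer would also accumulate weight gradients and call sgd.
            return dY * 3.0
        return Y, finish_update

    Y, backprop = begin_update(numpy.ones((2, 4)))
    dX = backprop(numpy.ones((2, 4)))  # same shape as X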

View File

@@ -1,5 +1,5 @@
 from thinc.api import layerize, chain, clone, concatenate, with_flatten
-from thinc.neural import Model, Maxout, Softmax
+from thinc.neural import Model, Maxout, Softmax, Affine
 from thinc.neural._classes.hash_embed import HashEmbed
 from thinc.neural._classes.convolution import ExtractWindow
@@ -21,11 +21,41 @@ def build_model(state2vec, width, depth, nr_class):
             state2vec
             >> Maxout(width, 1344)
             >> Maxout(width, width)
-            >> Softmax(nr_class, width)
+            >> Affine(nr_class, width)
         )
     return model


+def build_debug_model(state2vec, width, depth, nr_class):
+    with Model.define_operators({'>>': chain, '**': clone}):
+        model = (
+            state2vec
+            >> Maxout(width)
+            >> Affine(nr_class)
+        )
+    return model
+
+
+def build_debug_state2vec(width, nr_vector=1000, nF=1, nB=0, nS=1, nL=2, nR=2):
+    ops = Model.ops
+    def forward(tokens_attrs_vectors, drop=0.):
+        tokens, attr_vals, tokvecs = tokens_attrs_vectors
+        orig_tokvecs_shape = tokvecs.shape
+        tokvecs = tokvecs.reshape((tokvecs.shape[0], tokvecs.shape[1] *
+                                   tokvecs.shape[2]))
+        vector = tokvecs
+        def backward(d_vector, sgd=None):
+            d_tokvecs = vector.reshape(orig_tokvecs_shape)
+            return (tokens, d_tokvecs)
+        return vector, backward
+    model = layerize(forward)
+    return model
+
+
 def build_parser_state2vec(width, nr_vector=1000, nF=1, nB=0, nS=1, nL=2, nR=2):
     embed_tags = _reshape(chain(get_col(0), HashEmbed(16, nr_vector)))
     embed_deps = _reshape(chain(get_col(1), HashEmbed(16, nr_vector)))
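
Note: `build_debug_state2vec` wires a plain function into the network with thinc's `layerize`: the forward pass returns its output plus a `backward` closure, and since a reshape is linear, backprop is just the inverse reshape. A NumPy restatement of that forward/backward pair (the shapes are illustrative assumptions):

    import numpy

    def reshape_layer(tokvecs):
        # Flatten (n_states, n_context, width) -> (n_states, n_context * width).
        orig_shape = tokvecs.shape
        flat = tokvecs.reshape((orig_shape[0], orig_shape[1] * orig_shape[2]))
        def backward(d_flat, sgd=None):
            # A reshape moves no values between cells, so the gradient is
            # the output gradient reshaped back to the input's shape.
            return d_flat.reshape(orig_shape)
        return flat, backward

    X = numpy.zeros((8, 3, 64), dtype='f')
    Y, bwd = reshape_layer(X)
    assert bwd(Y).shape == X.shape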

View File

@@ -28,6 +28,8 @@ from murmurhash.mrmr cimport hash64
 from preshed.maps cimport MapStruct
 from preshed.maps cimport map_get

+from numpy import exp
+
 from . import _parse_features
 from ._parse_features cimport CONTEXT_SIZE
 from ._parse_features cimport fill_context
@@ -43,6 +45,7 @@ from ..gold cimport GoldParse
 from ..attrs cimport TAG, DEP
 from .._ml import build_parser_state2vec, build_model
+from .._ml import build_debug_state2vec, build_debug_model

 USE_FTRL = True
@@ -111,8 +114,8 @@ cdef class Parser:
         return (Parser, (self.vocab, self.moves, self.model), None, None)

     def build_model(self, width=8, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_):
-        state2vec = build_parser_state2vec(width, nr_vector, nF, nB, nL, nR)
-        model = build_model(state2vec, width, 2, self.moves.n_moves)
+        state2vec = build_debug_state2vec(width, nr_vector, nF, nB, nL, nR)
+        model = build_debug_model(state2vec, width, 2, self.moves.n_moves)
         return model

     def __call__(self, Doc tokens):
@@ -166,32 +169,22 @@ cdef class Parser:
         cdef Doc doc
         cdef StateClass state
         cdef int guess
-        is_valid = self.model.ops.allocate((len(docs), nr_class), dtype='i')
         tokvecs = [d.tensor for d in docs]
-        attr_names = self.model.ops.allocate((2,), dtype='i')
-        attr_names[0] = TAG
-        attr_names[1] = DEP
         all_states = list(states)
         todo = zip(states, tokvecs)
         while todo:
             states, tokvecs = zip(*todo)
-            features = self._get_features(states, tokvecs, attr_names)
-            scores = self.model.predict(features)
-            self._validate_batch(is_valid, states)
-            scores *= is_valid
+            scores, _ = self._begin_update(states, tokvecs)
             for state, guess in zip(states, scores.argmax(axis=1)):
                 action = self.moves.c[guess]
                 action.do(state.c, action.label)
-            todo = filter(lambda sp: not sp[0].is_final(), todo)
+            todo = filter(lambda sp: not sp[0].py_is_final(), todo)
         for state, doc in zip(all_states, docs):
             self.moves.finalize_state(state.c)
             for i in range(doc.length):
                 doc.c[i] = state.c._sent[i]

-    def update(self, docs, golds, drop=0., sgd=None):
-        if isinstance(docs, Doc) and isinstance(golds, GoldParse):
-            return self.update([docs], [golds], drop=drop)
+    def begin_training(self, docs, golds):
         for gold in golds:
             self.moves.preprocess_gold(gold)
         states = self._init_states(docs)
@@ -204,39 +197,60 @@ cdef class Parser:
         attr_names = self.model.ops.allocate((2,), dtype='i')
         attr_names[0] = TAG
         attr_names[1] = DEP
+        features = self._get_features(states, tokvecs, attr_names)
+        self.model.begin_training(features)
+
+    def update(self, docs, golds, drop=0., sgd=None):
+        if isinstance(docs, Doc) and isinstance(golds, GoldParse):
+            return self.update([docs], [golds], drop=drop)
+        for gold in golds:
+            self.moves.preprocess_gold(gold)
+        states = self._init_states(docs)
+        tokvecs = [d.tensor for d in docs]
+        d_tokens = [self.model.ops.allocate(d.tensor.shape) for d in docs]
+        nr_class = self.moves.n_moves
         output = list(d_tokens)
         todo = zip(states, tokvecs, golds, d_tokens)
         assert len(states) == len(todo)
         loss = 0.
         while todo:
             states, tokvecs, golds, d_tokens = zip(*todo)
-            features = self._get_features(states, tokvecs, attr_names)
-            scores, finish_update = self.model.begin_update(features, drop=drop)
-            assert scores.shape == (len(states), self.moves.n_moves), (len(states), scores.shape)
-            self._cost_batch(costs, is_valid, states, golds)
-            scores *= is_valid
-            self._set_gradient(gradients, scores, costs)
-            loss += numpy.abs(gradients).sum() / gradients.shape[0]
-            token_ids, batch_token_grads = finish_update(gradients, sgd=sgd)
+            scores, finish_update = self._begin_update(states, tokvecs)
+            token_ids, batch_token_grads = finish_update(golds, sgd=sgd)
             for i, tok_i in enumerate(token_ids):
                 d_tokens[i][tok_i] += batch_token_grads[i]
             self._transition_batch(states, scores)
             # Get unfinished states (and their matching gold and token gradients)
-            todo = filter(lambda sp: not sp[0].is_final(), todo)
-            costs = costs[:len(todo)]
-            is_valid = is_valid[:len(todo)]
-            gradients = gradients[:len(todo)]
-            gradients.fill(0)
-            costs.fill(0)
-            is_valid.fill(1)
+            todo = filter(lambda sp: not sp[0].py_is_final(), todo)
         return output, loss

+    def _begin_update(self, states, tokvecs, drop=0.):
+        nr_class = self.moves.n_moves
+        attr_names = self.model.ops.allocate((2,), dtype='i')
+        attr_names[0] = TAG
+        attr_names[1] = DEP
+        features = self._get_features(states, tokvecs, attr_names)
+        scores, finish_update = self.model.begin_update(features, drop=drop)
+        is_valid = self.model.ops.allocate((len(states), nr_class), dtype='i')
+        self._validate_batch(is_valid, states)
+        softmaxed = self.model.ops.softmax(scores)
+        softmaxed *= is_valid
+        softmaxed /= softmaxed.sum(axis=1)
+        print('Scores', softmaxed[0])
+        def backward(golds, sgd=None):
+            costs = self.model.ops.allocate((len(states), nr_class), dtype='f')
+            d_scores = self.model.ops.allocate((len(states), nr_class), dtype='f')
+            self._cost_batch(costs, is_valid, states, golds)
+            self._set_gradient(d_scores, scores, is_valid, costs)
+            return finish_update(d_scores, sgd=sgd)
+        return softmaxed, backward
+
     def _init_states(self, docs):
         states = []
         cdef Doc doc
@@ -281,20 +295,20 @@ cdef class Parser:
             action = self.moves.c[guess]
             action.do(state.c, action.label)

-    def _set_gradient(self, gradients, scores, costs):
+    def _set_gradient(self, gradients, scores, is_valid, costs):
         """Do multi-label log loss"""
         cdef double Z, gZ, max_, g_max
-        g_scores = scores * (costs <= 0)
-        maxes = scores.max(axis=1).reshape((scores.shape[0], 1))
-        g_maxes = g_scores.max(axis=1).reshape((g_scores.shape[0], 1))
-        exps = numpy.exp((scores-maxes))
-        g_exps = numpy.exp(g_scores-g_maxes)
-        Zs = exps.sum(axis=1).reshape((exps.shape[0], 1))
-        gZs = g_exps.sum(axis=1).reshape((g_exps.shape[0], 1))
-        logprob = exps / Zs
-        g_logprob = g_exps / gZs
-        gradients[:] = logprob - g_logprob
+        scores = scores * is_valid
+        g_scores = scores * is_valid * (costs <= 0.)
+        exps = numpy.exp(scores - scores.max(axis=1))
+        exps *= is_valid
+        g_exps = numpy.exp(g_scores - g_scores.max(axis=1))
+        g_exps *= costs <= 0.
+        g_exps *= is_valid
+        gradients[:] = exps / exps.sum(axis=1)
+        gradients -= g_exps / g_exps.sum(axis=1)
+        print('Gradient', gradients[0])
+        print('Costs', costs[0])

     def step_through(self, Doc doc, GoldParse gold=None):
         """

View File

@@ -34,7 +34,7 @@ cdef class StateClass:
     def token_vector_lenth(self):
         return self.doc.tensor.shape[1]

-    def is_final(self):
+    def py_is_final(self):
         return self.c.is_final()

     def print_state(self, words):
@@ -47,31 +47,38 @@ cdef class StateClass:
         return ' '.join((third, second, top, '|', n0, n1))

     def nr_context_tokens(self, int nF, int nB, int nS, int nL, int nR):
-        return 1+nF+nB+nS + nL + (nS * nL) + (nS * nR)
+        return 3
+        #return 1+nF+nB+nS + nL + (nS * nL) + (nS * nR)

     def set_context_tokens(self, int[:] output, nF=1, nB=0, nS=2,
                            nL=2, nR=2):
         output[0] = self.B(0)
         output[1] = self.S(0)
         output[2] = self.S(1)
-        output[3] = self.L(self.S(0), 1)
-        output[4] = self.L(self.S(0), 2)
-        output[5] = self.R(self.S(0), 1)
-        output[6] = self.R(self.S(0), 2)
-        output[7] = self.L(self.S(1), 1)
-        output[8] = self.L(self.S(1), 2)
-        output[9] = self.R(self.S(1), 1)
-        output[10] = self.R(self.S(1), 2)
+        #output[3] = self.L(self.S(0), 1)
+        #output[4] = self.L(self.S(0), 2)
+        #output[5] = self.R(self.S(0), 1)
+        #output[6] = self.R(self.S(0), 2)
+        #output[7] = self.L(self.S(1), 1)
+        #output[8] = self.L(self.S(1), 2)
+        #output[9] = self.R(self.S(1), 1)
+        #output[10] = self.R(self.S(1), 2)

     def set_attributes(self, uint64_t[:, :] vals, int[:] tokens, int[:] names):
         cdef int i, j, tok_i
         for i in range(tokens.shape[0]):
             tok_i = tokens[i]
-            token = &self.c._sent[tok_i]
-            for j in range(names.shape[0]):
-                vals[i, j] = Token.get_struct_attr(token, <attr_id_t>names[j])
+            if tok_i >= 0:
+                token = &self.c._sent[tok_i]
+                for j in range(names.shape[0]):
+                    vals[i, j] = Token.get_struct_attr(token, <attr_id_t>names[j])
+            else:
+                vals[i] = 0

     def set_token_vectors(self, float[:, :] tokvecs,
                           float[:, :] all_tokvecs, int[:] indices):
         for i in range(indices.shape[0]):
-            tokvecs[i] = all_tokvecs[indices[i]]
+            if indices[i] >= 0:
+                tokvecs[i] = all_tokvecs[indices[i]]
+            else:
+                tokvecs[i] = 0
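
Note: in both helpers a negative index marks an empty context slot (e.g. `S(1)` when the stack holds fewer than two tokens), so its attribute row or vector is zeroed instead of reading out of bounds. The same guard in plain NumPy (a sketch, not the Cython code above):

    import numpy

    def gather_token_vectors(all_tokvecs, indices):
        # Rows for missing context tokens (index < 0) stay zero.
        out = numpy.zeros((len(indices), all_tokvecs.shape[1]), dtype='f')
        for i, idx in enumerate(indices):
            if idx >= 0:
                out[i] = all_tokvecs[idx]
        return out

    vecs = numpy.random.uniform(-1., 1., (5, 64)).astype('f')
    context = gather_token_vectors(vecs, [0, 2, -1])  # B(0), S(0), missing S(1)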