Fix beam parser. Starting to work

This commit is contained in:
Matthew Honnibal 2016-07-24 01:14:56 +02:00
parent e2a9a68b66
commit 27176c3d2f
2 changed files with 146 additions and 194 deletions

View File

@ -26,10 +26,11 @@ from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
from util import Config
from thinc.linear.features cimport ConjunctionExtracter
from thinc.structs cimport FeatureC
from thinc.structs cimport FeatureC, ExampleC
from thinc.extra.search cimport Beam
from thinc.extra.search cimport MaxViolation
from thinc.extra.eg cimport Example
from ..structs cimport TokenC
@ -46,6 +47,7 @@ from ._parse_features cimport fill_context
from .stateclass cimport StateClass
from .parser cimport Parser
from .parser cimport ParserPerceptron
from .parser cimport ParserNeuralNet
DEBUG = False
def set_debug(val):
@ -78,7 +80,6 @@ cdef class BeamParser(Parser):
self._parseC(tokens, length, nr_feat, nr_class)
cdef int _parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) except -1:
cdef Beam beam = Beam(self.moves.n_moves, self.beam_width)
beam.initialize(_init_state, length, tokens)
beam.check_done(_check_final_state, NULL)
@ -104,34 +105,39 @@ cdef class BeamParser(Parser):
while not pred.is_done and not gold.is_done:
self._advance_beam(pred, gold_parse, False)
self._advance_beam(gold, gold_parse, True)
violn.check(pred, gold)
self.model.time += 1
if pred.is_done and pred.loss == 0:
pass
elif pred.is_done and pred.loss > 0:
self._update(tokens, pred.histories[0], -1.0)
self._update(tokens, gold.histories[0], 1.0)
elif violn.cost > 0:
self._update(tokens, violn.p_hist, -1.0)
self._update(tokens, violn.g_hist, 1.0)
if pred.min_score > gold.score:
break
#print(pred.score, pred.min_score, gold.score)
cdef long double Z = 0.0
for i in range(pred.size):
if pred._states[i].loss > 0:
Z += exp(pred._states[i].score)
if Z > 0:
Z += exp(gold.score)
for i, hist in enumerate(pred.histories):
if pred._states[i].loss > 0:
self._update_dense(tokens, hist, exp(pred._states[i].score) / Z)
self._update_dense(tokens, gold.histories[0], (exp(gold.score) / Z) - 1)
_cleanup(pred)
_cleanup(gold)
return pred.loss
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
cdef atom_t[CONTEXT_SIZE] context
cdef Pool mem = Pool()
features = <FeatureC*>mem.alloc(self.model.nr_feat, sizeof(FeatureC))
cdef ParserPerceptron model = self.model
cdef Example py_eg = Example(nr_class=self.moves.n_moves, nr_atom=CONTEXT_SIZE,
nr_feat=self.model.nr_feat, widths=self.model.widths)
cdef ExampleC* eg = py_eg.c
cdef ParserNeuralNet model = self.model
for i in range(beam.size):
py_eg.reset()
stcls = <StateClass>beam.at(i)
if not stcls.c.is_final():
fill_context(context, stcls.c)
nr_feat = model.extracter.set_features(features, context)
self.model.set_scoresC(beam.scores[i], features, nr_feat, 1)
model.set_featuresC(eg, stcls.c)
model.set_scoresC(beam.scores[i], eg.features, eg.nr_feat, 1)
self.moves.set_valid(beam.is_valid[i], stcls.c)
if gold is not None:
for i in range(beam.size):
py_eg.reset()
stcls = <StateClass>beam.at(i)
if not stcls.c.is_final():
self.moves.set_costs(beam.is_valid[i], beam.costs[i], stcls, gold)
@ -141,88 +147,24 @@ cdef class BeamParser(Parser):
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
beam.check_done(_check_final_state, NULL)
def _maxent_update_dense(self, doc, pred_scores, pred_hist, gold_scores,
gold_hist, step_size=0.001):
for i, history in enumerate(pred_hist):
stcls = StateClass.init(doc.c, doc.length)
self.moves.initialize_state(stcls.c)
for j, clas in enumerate(history):
fill_context(context, stcls.c)
nr_feat = model.extracter.set_features(features, context)
self.moves.set_valid(is_valid, stcls)
# Move weight away from this outcome
for i in range(nr_class):
costs[i] = 0.0
costs[clas] = 1.0
self.update(features, nr_feat, True, costs, is_valid, False)
self.moves.c[clas].do(stcls.c, self.moves.c[clas].label)
for i, history in enumerate(gold_hist):
stcls = StateClass.init(doc.c, doc.length)
self.moves.initialize_state(stcls.c)
for j, clas in enumerate(history):
fill_context(context, stcls.c)
nr_feat = model.extracter.set_features(features, context)
self.moves.set_valid(is_valid, stcls)
# Move weight towards this outcome
for i in range(nr_class):
costs[i] = 1.0
costs[clas] = 0.0
self.update(features, nr_feat, True, costs, is_valid, False)
self.moves.c[clas].do(stcls.c, self.moves.c[clas].label)
def _maxent_update(self, doc, pred_scores, pred_hist, gold_scores, gold_hist,
step_size=0.001):
cdef weight_t Z, gZ, value
cdef feat_t feat
cdef class_t clas
gZ, g_counts = self._maxent_counts(doc, gold_scores, gold_hist)
Z, counts = self._maxent_counts(doc, pred_scores, pred_hist)
update = {}
if gZ > 0:
for (clas, feat), value in g_counts.items():
update[(clas, feat)] = value / gZ
Z += gZ
for (clas, feat), value in counts.items():
update.setdefault((clas, feat), 0.0)
update[(clas, feat)] -= value / Z
for (clas, feat), value in update.items():
if value < 1000:
self.model.update_weight(feat, clas, step_size * value)
def _maxent_counts(self, Doc doc, scores, history):
cdef Pool mem = Pool()
cdef atom_t[CONTEXT_SIZE] context
features = <FeatureC*>mem.alloc(self.model.nr_feat, sizeof(FeatureC))
cdef StateClass stcls
cdef class_t clas
cdef ParserPerceptron model = self.model
cdef weight_t Z = 0.0
cdef weight_t score
counts = {}
for i, (score, history) in enumerate(zip(scores, history)):
prob = exp(score)
if prob < 1e-6:
continue
stcls = StateClass.init(doc.c, doc.length)
self.moves.initialize_state(stcls.c)
for clas in history:
fill_context(context, stcls.c)
nr_feat = model.extracter.set_features(features, context)
for feat in features[:nr_feat]:
key = (clas, feat.key)
counts[key] = counts.get(key, 0.0) + feat.value
self.moves.c[clas].do(stcls.c, self.moves.c[clas].label)
for key in counts:
counts[key] *= prob
Z += prob
return Z, counts
def _update_dense(self, Doc doc, history, weight_t loss):
cdef Example py_eg = Example(nr_class=self.moves.n_moves,
nr_atom=CONTEXT_SIZE,
nr_feat=self.model.nr_feat,
widths=self.model.widths)
cdef ExampleC* eg = py_eg.c
cdef ParserNeuralNet model = self.model
stcls = StateClass.init(doc.c, doc.length)
self.moves.initialize_state(stcls.c)
for clas in history:
model.set_featuresC(eg, stcls.c)
self.moves.set_valid(eg.is_valid, stcls.c)
for i in range(self.moves.n_moves):
eg.costs[i] = loss if i == clas else 0
model.updateC(
eg.features, eg.nr_feat, True, eg.costs, eg.is_valid, False)
self.moves.c[clas].do(stcls.c, self.moves.c[clas].label)
py_eg.reset()
def _update(self, Doc tokens, list hist, weight_t inc):
cdef Pool mem = Pool()
@ -278,7 +220,88 @@ cdef hash_t _hash_state(void* _state, void* _) except 0:
#return <uint64_t>state.c
return state.c.hash()
#
# def _maxent_update(self, Doc doc, pred_scores, pred_hist, gold_scores, gold_hist):
# Z = 0
# for i, (score, history) in enumerate(zip(pred_scores, pred_hist)):
# prob = exp(score)
# if prob < 1e-6:
# continue
# stcls = StateClass.init(doc.c, doc.length)
# self.moves.initialize_state(stcls.c)
# for clas in history:
# delta_loss[clas] = prob * 1/Z
# gradient = [(input_ * prob) / Z for input_ in hidden]
# fill_context(context, stcls.c)
# nr_feat = model.extracter.set_features(features, context)
# for feat in features[:nr_feat]:
# key = (clas, feat.key)
# counts[key] = counts.get(key, 0.0) + feat.value
# self.moves.c[clas].do(stcls.c, self.moves.c[clas].label)
# for key in counts:
# counts[key] *= prob
# Z += prob
# gZ, g_counts = self._maxent_counts(doc, gold_scores, gold_hist)
# for (clas, feat), value in g_counts.items():
# self.model.update_weight(feat, clas, value / gZ)
#
# Z, counts = self._maxent_counts(doc, pred_scores, pred_hist)
# for (clas, feat), value in counts.items():
# self.model.update_weight(feat, clas, -value / (Z + gZ))
#
#
# def _maxent_update(self, doc, pred_scores, pred_hist, gold_scores, gold_hist,
# step_size=0.001):
# cdef weight_t Z, gZ, value
# cdef feat_t feat
# cdef class_t clas
# gZ, g_counts = self._maxent_counts(doc, gold_scores, gold_hist)
# Z, counts = self._maxent_counts(doc, pred_scores, pred_hist)
# update = {}
# if gZ > 0:
# for (clas, feat), value in g_counts.items():
# update[(clas, feat)] = value / gZ
# Z += gZ
# for (clas, feat), value in counts.items():
# update.setdefault((clas, feat), 0.0)
# update[(clas, feat)] -= value / Z
# for (clas, feat), value in update.items():
# if value < 1000:
# self.model.update_weight(feat, clas, step_size * value)
#
# def _maxent_counts(self, Doc doc, scores, history):
# cdef Pool mem = Pool()
# cdef atom_t[CONTEXT_SIZE] context
# features = <FeatureC*>mem.alloc(self.model.nr_feat, sizeof(FeatureC))
#
# cdef StateClass stcls
#
# cdef class_t clas
# cdef ParserPerceptron model = self.model
#
# cdef weight_t Z = 0.0
# cdef weight_t score
# counts = {}
# for i, (score, history) in enumerate(zip(scores, history)):
# prob = exp(score)
# if prob < 1e-6:
# continue
# stcls = StateClass.init(doc.c, doc.length)
# self.moves.initialize_state(stcls.c)
# for clas in history:
# fill_context(context, stcls.c)
# nr_feat = model.extracter.set_features(features, context)
# for feat in features[:nr_feat]:
# key = (clas, feat.key)
# counts[key] = counts.get(key, 0.0) + feat.value
# self.moves.c[clas].do(stcls.c, self.moves.c[clas].label)
# for key in counts:
# counts[key] *= prob
# Z += prob
# return Z, counts
#
#
# def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold, words):
# cdef atom_t[CONTEXT_SIZE] context

View File

@ -13,6 +13,7 @@ from cpython.exc cimport PyErr_CheckSignals
from libc.stdint cimport uint32_t, uint64_t
from libc.string cimport memset, memcpy
from libc.stdlib cimport malloc, calloc, free
from libc.math cimport exp
import os.path
from os import path
import shutil
@ -106,7 +107,7 @@ cdef class ParserPerceptron(AveragedPerceptron):
cdef class ParserNeuralNet(NeuralNet):
def __init__(self, shape, **kwargs):
vector_widths = [4] * 57
vector_widths = [4] * 76
slots = [0, 1, 2, 3] # S0
slots += [4, 5, 6, 7] # S1
slots += [8, 9, 10, 11] # S2
@ -119,11 +120,10 @@ cdef class ParserNeuralNet(NeuralNet):
slots += [36, 37, 38, 39] * 2 # B0l, B0r
slots += [40, 41, 42, 43] * 2 # S1l, S1r
slots += [44, 45, 46, 47] * 2 # S2l, S2r
slots += [48, 49, 50, 51, 52]
slots += [48, 49, 50, 51, 52, 53, 54, 55]
slots += [53, 54, 55, 56]
input_length = sum(vector_widths[slot] for slot in slots)
widths = [input_length] + shape[3:]
widths = [input_length] + shape
NeuralNet.__init__(self, widths, embed=(vector_widths, slots), **kwargs)
@property
@ -156,15 +156,26 @@ cdef class ParserNeuralNet(NeuralNet):
feats = _add_pos_bigram(feats, 65, state.S_(1), state.S_(0))
feats = _add_pos_bigram(feats, 66, state.S_(1), state.B_(0))
feats = _add_pos_bigram(feats, 67, state.S_(0), state.B_(1))
feats = _add_pos_bigram(feats, 68, state.B_(0), state.B_(1))
feats = _add_pos_trigram(feats, 69, state.S_(1), state.S_(0), state.B_(0))
feats = _add_pos_trigram(feats, 70, state.S_(0), state.B_(0), state.B_(1))
feats = _add_pos_trigram(feats, 71, state.S_(0), state.R_(state.S(0), 1),
feats = _add_pos_bigram(feats, 68, state.S_(0), state.R_(state.S(0), 1))
feats = _add_pos_bigram(feats, 69, state.S_(0), state.R_(state.S(0), 2))
feats = _add_pos_bigram(feats, 70, state.S_(0), state.L_(state.S(0), 1))
feats = _add_pos_bigram(feats, 71, state.S_(0), state.L_(state.S(0), 2))
feats = _add_pos_trigram(feats, 72, state.S_(1), state.S_(0), state.B_(0))
feats = _add_pos_trigram(feats, 73, state.S_(0), state.B_(0), state.B_(1))
feats = _add_pos_trigram(feats, 74, state.S_(0), state.R_(state.S(0), 1),
state.R_(state.S(0), 2))
feats = _add_pos_trigram(feats, 72, state.S_(0), state.L_(state.S(0), 1),
feats = _add_pos_trigram(feats, 75, state.S_(0), state.L_(state.S(0), 1),
state.L_(state.S(0), 2))
eg.nr_feat = feats - eg.features
cdef void _set_delta_lossC(self, weight_t* delta_loss,
const weight_t* Zs, const weight_t* scores) nogil:
for i in range(self.c.widths[self.c.nr_layer-1]):
delta_loss[i] = Zs[i]
cdef void _softmaxC(self, weight_t* out) nogil:
pass
cdef inline FeatureC* _add_token(FeatureC* feats,
int slot, const TokenC* token, weight_t value) nogil:
@ -230,80 +241,6 @@ cdef inline FeatureC* _add_pos_trigram(FeatureC* feat, int slot,
feat.value = 1.0
return feat+1
cdef class ParserNeuralNetEnsemble(ParserNeuralNet):
def __init__(self, shape, update_step='sgd', eta=0.01, rho=0.0, n=5):
ParserNeuralNet.__init__(self, shape, update_step=update_step, eta=eta, rho=rho)
self._models_c = <NeuralNetC**>self.mem.alloc(sizeof(NeuralNetC*), n)
self._masks = <int**>self.mem.alloc(sizeof(int*), n)
self._models = []
cdef ParserNeuralNet model
threshold = 1.5 / n
self._nr_model = n
for i in range(n):
self._masks[i] = <int*>self.mem.alloc(sizeof(int), self.nr_feat)
for j in range(self.nr_feat):
self._masks[i][j] = random.random() < threshold
# We have to pass our pool here, because the embedding table passes
# it around.
model = ParserNeuralNet(shape, update_step=update_step, eta=eta, rho=rho)
self._models_c[i] = &model.c
self._models.append(model)
property eta:
def __get__(self):
return self._models[0].eta
def __set__(self, weight_t value):
for model in self._models:
model.eta = value
def sparsify_embeddings(self, penalty):
p = 0.0
for model in self._models:
p += model.sparsify_embeddings(penalty)
return p / len(self._models)
cdef void set_scoresC(self, weight_t* scores, const void* _feats,
int nr_feat, int is_sparse) nogil:
nr_class = self.c.widths[self.c.nr_layer-1]
sub_scores = <weight_t*>calloc(sizeof(weight_t), nr_class)
sub_feats = <FeatureC*>calloc(sizeof(FeatureC), nr_feat)
feats = <const FeatureC*>_feats
for i in range(self._nr_model):
for j in range(nr_feat):
sub_feats[j] = feats[j]
sub_feats[j].value *= self._masks[i][j]
self.c = self._models_c[i][0]
self.c.weights = self._models_c[i].weights
self.c.gradient = self._models_c[i].gradient
ParserNeuralNet.set_scoresC(self, sub_scores, sub_feats, nr_feat, 1)
for j in range(nr_class):
scores[j] += sub_scores[j]
sub_scores[j] = 0.0
for j in range(nr_class):
scores[j] /= self._nr_model
free(sub_feats)
free(sub_scores)
def update(self, Example eg):
if eg.cost == 0:
return 0.0
loss = 0.0
full_feats = <FeatureC*>calloc(sizeof(FeatureC), eg.nr_feat)
memcpy(full_feats, eg.c.features, sizeof(FeatureC) * eg.nr_feat)
cdef ParserNeuralNet model
for i, model in enumerate(self._models):
for j in range(eg.nr_feat):
eg.c.features[j].value *= self._masks[i][j]
loss += model.update(eg)
memcpy(eg.c.features, full_feats, sizeof(FeatureC) * eg.nr_feat)
free(full_feats)
return loss
def end_training(self):
for model in self._models:
model.end_training()
cdef class Parser:
def __init__(self, StringStore strings, transition_system, model):
@ -320,16 +257,8 @@ cdef class Parser:
moves = transition_system(strings, cfg.labels)
if cfg.get('model') == 'neural':
shape = [cfg.vector_widths, cfg.slots, cfg.feat_set]
shape.extend(cfg.hidden_layers)
shape.append(moves.n_moves)
if cfg.get('ensemble_size') >= 2:
model = ParserNeuralNetEnsemble(shape, update_step=cfg.update_step,
eta=cfg.eta, rho=cfg.rho,
n=cfg.ensemble_size)
else:
model = ParserNeuralNet(shape, update_step=cfg.update_step,
eta=cfg.eta, rho=cfg.rho)
model = ParserNeuralNet(cfg.hidden_layers + [moves.n_moves],
update_step=cfg.update_step, eta=cfg.eta, rho=cfg.rho)
else:
model = ParserPerceptron(get_templates(cfg.feat_set))