Tmp. Working on NN NER.

This commit is contained in:
Matthew Honnibal 2016-09-08 13:00:13 +02:00
parent 7c7a05a466
commit b3b180010b
3 changed files with 102 additions and 49 deletions

View File

@ -23,6 +23,7 @@ from ._state cimport StateC
from ._parse_features cimport fill_context from ._parse_features cimport fill_context
from ._parse_features cimport CONTEXT_SIZE from ._parse_features cimport CONTEXT_SIZE
from ._parse_features cimport fill_context from ._parse_features cimport fill_context
from ._parse_features import ner as ner_templates
from ._parse_features cimport * from ._parse_features cimport *
from .transition_system cimport TransitionSystem from .transition_system cimport TransitionSystem
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
@ -45,6 +46,10 @@ cdef class ParserPerceptron(AveragedPerceptron):
# Clip to guess and best, to keep gradient sparse. # Clip to guess and best, to keep gradient sparse.
d_losses[guess] = -2 * (-eg.c.costs[guess] - eg.c.scores[guess]) d_losses[guess] = -2 * (-eg.c.costs[guess] - eg.c.scores[guess])
d_losses[best] = -2 * (-eg.c.costs[best] - eg.c.scores[best]) d_losses[best] = -2 * (-eg.c.costs[best] - eg.c.scores[best])
#for i in range(eg.c.nr_class):
# if eg.c.is_valid[i] \
# and eg.c.scores[i] >= eg.c.scores[best]:
# d_losses[i] = -2 * (-eg.c.costs[i] - eg.c.scores[i])
elif loss == 'nll': elif loss == 'nll':
# Clip to guess and best, to keep gradient sparse. # Clip to guess and best, to keep gradient sparse.
if eg.c.scores[guess] == 0.0: if eg.c.scores[guess] == 0.0:
@ -69,11 +74,11 @@ cdef class ParserPerceptron(AveragedPerceptron):
d_losses = {best: -1.0, guess: 1.0} d_losses = {best: -1.0, guess: 1.0}
step = 0.0 step = 0.0
i = 0 i = 0
for clas, d_loss in d_losses.items(): for clas, d_loss in sorted(d_losses.items()):
for feat in eg.c.features[:eg.c.nr_feat]: for feat in eg.c.features[:eg.c.nr_feat]:
step += abs(self.update_weight(feat.key, clas, feat.value * d_loss)) self.update_weight(feat.key, clas, feat.value * d_loss)
i += 1 i += 1
self.total_L1 += self.l1_penalty * self.learn_rate #self.total_L1 += self.l1_penalty * self.learn_rate
return sum(map(abs, d_losses.values())) return sum(map(abs, d_losses.values()))
cdef int set_featuresC(self, FeatureC* feats, const void* _state) nogil: cdef int set_featuresC(self, FeatureC* feats, const void* _state) nogil:
@ -95,38 +100,53 @@ cdef class ParserPerceptron(AveragedPerceptron):
for clas in history: for clas in history:
nr_feat = self.set_featuresC(features, stcls.c) nr_feat = self.set_featuresC(features, stcls.c)
for feat in features[:nr_feat]: for feat in features[:nr_feat]:
self.update_weight(feat.key, clas, feat.value * -grad) self.update_weight(feat.key, clas, feat.value * grad)
moves.c[clas].do(stcls.c, moves.c[clas].label) moves.c[clas].do(stcls.c, moves.c[clas].label)
cdef class ParserNeuralNet(NeuralNet): cdef class ParserNeuralNet(NeuralNet):
def __init__(self, shape, **kwargs): def __init__(self, shape, **kwargs):
vector_widths = [4] * 76 if kwargs.get('feat_set', 'parser') == 'parser':
slots = [0, 1, 2, 3] # S0 vector_widths = [4] * 76
slots += [4, 5, 6, 7] # S1 slots = [0, 1, 2, 3] # S0
slots += [8, 9, 10, 11] # S2 slots += [4, 5, 6, 7] # S1
slots += [12, 13, 14, 15] # S3+ slots += [8, 9, 10, 11] # S2
slots += [16, 17, 18, 19] # B0 slots += [12, 13, 14, 15] # S3+
slots += [20, 21, 22, 23] # B1 slots += [16, 17, 18, 19] # B0
slots += [24, 25, 26, 27] # B2 slots += [20, 21, 22, 23] # B1
slots += [28, 29, 30, 31] # B3+ slots += [24, 25, 26, 27] # B2
slots += [32, 33, 34, 35] * 2 # S0l, S0r slots += [28, 29, 30, 31] # B3+
slots += [36, 37, 38, 39] * 2 # B0l, B0r slots += [32, 33, 34, 35] * 2 # S0l, S0r
slots += [40, 41, 42, 43] * 2 # S1l, S1r slots += [36, 37, 38, 39] * 2 # B0l, B0r
slots += [44, 45, 46, 47] * 2 # S2l, S2r slots += [40, 41, 42, 43] * 2 # S1l, S1r
slots += [48, 49, 50, 51, 52, 53, 54, 55] slots += [44, 45, 46, 47] * 2 # S2l, S2r
slots += [53, 54, 55, 56] slots += [48, 49, 50, 51, 52, 53, 54, 55]
slots += [53, 54, 55, 56]
self.extracter = None
else:
templates = ner_templates
vector_widths = [4] * len(templates)
slots = list(range(templates))
self.extracter = ConjunctionExtracter(templates)
input_length = sum(vector_widths[slot] for slot in slots) input_length = sum(vector_widths[slot] for slot in slots)
widths = [input_length] + shape widths = [input_length] + shape
NeuralNet.__init__(self, widths, embed=(vector_widths, slots), **kwargs) NeuralNet.__init__(self, widths, embed=(vector_widths, slots), **kwargs)
@property @property
def nr_feat(self): def nr_feat(self):
return 2000 if self.extracter is None:
return 2000
else:
return self.extracter.nr_feat
cdef int set_featuresC(self, FeatureC* feats, const void* _state) nogil: cdef int set_featuresC(self, FeatureC* feats, const void* _state) nogil:
memset(feats, 0, 2000 * sizeof(FeatureC)) cdef atom_t[CONTEXT_SIZE] context
state = <const StateC*>_state state = <const StateC*>_state
if self.extracter is not None:
fill_context(context, state)
return self.extracter.set_features(feats, context)
memset(feats, 0, 2000 * sizeof(FeatureC))
start = feats start = feats
feats = _add_token(feats, 0, state.S_(0), 1.0) feats = _add_token(feats, 0, state.S_(0), 1.0)
@ -161,13 +181,13 @@ cdef class ParserNeuralNet(NeuralNet):
state.L_(state.S(0), 2)) state.L_(state.S(0), 2))
return feats - start return feats - start
cdef void _set_delta_lossC(self, weight_t* delta_loss, #cdef void _set_delta_lossC(self, weight_t* delta_loss,
const weight_t* cost, const weight_t* scores) nogil: # const weight_t* cost, const weight_t* scores) nogil:
for i in range(self.c.widths[self.c.nr_layer-1]): # for i in range(self.c.widths[self.c.nr_layer-1]):
delta_loss[i] = cost[i] # delta_loss[i] = cost[i]
cdef void _softmaxC(self, weight_t* out) nogil: #cdef void _softmaxC(self, weight_t* out) nogil:
pass # pass
cdef void dropoutC(self, FeatureC* feats, weight_t drop_prob, cdef void dropoutC(self, FeatureC* feats, weight_t drop_prob,
int nr_feat) nogil: int nr_feat) nogil:
@ -241,6 +261,37 @@ cdef inline FeatureC* _add_token(FeatureC* feats,
return feats return feats
cdef inline FeatureC* _add_characters(FeatureC* feats,
int slot, uint64_t* chars, int length, weight_t value) with gil:
nr_start_chars = 4
nr_end_chars = 4
for i in range(min(nr_start_chars, length)):
feats.i = slot
feats.key = chars[i]
feats.value = value
feats += 1
slot += 1
for _ in range(length, nr_start_chars):
feats.i = slot
feats.key = 0
feats.value = 0
feats += 1
slot += 1
for i in range(min(nr_end_chars, length)):
feats.i = slot
feats.key = chars[(length-nr_end_chars)+i]
feats.value = value
feats += 1
slot += 1
for _ in range(length, nr_start_chars):
feats.i = slot
feats.key = 0
feats.value = 0
feats += 1
slot += 1
return feats
cdef inline FeatureC* _add_subtree(FeatureC* feats, int slot, const StateC* state, int t) nogil: cdef inline FeatureC* _add_subtree(FeatureC* feats, int slot, const StateC* state, int t) nogil:
value = 1.0 value = 1.0
for i in range(state.n_R(t)): for i in range(state.n_R(t)):

View File

@ -90,6 +90,9 @@ cdef class BeamParser(Parser):
cdef Beam beam = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density) cdef Beam beam = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density)
beam.initialize(_init_state, length, tokens) beam.initialize(_init_state, length, tokens)
beam.check_done(_check_final_state, NULL) beam.check_done(_check_final_state, NULL)
if beam.is_done:
_cleanup(beam)
return 0
while not beam.is_done: while not beam.is_done:
self._advance_beam(beam, None, False) self._advance_beam(beam, None, False)
state = <StateClass>beam.at(0) state = <StateClass>beam.at(0)
@ -100,7 +103,7 @@ cdef class BeamParser(Parser):
def train(self, Doc tokens, GoldParse gold_parse, itn=0): def train(self, Doc tokens, GoldParse gold_parse, itn=0):
self.moves.preprocess_gold(gold_parse) self.moves.preprocess_gold(gold_parse)
cdef Beam pred = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density) cdef Beam pred = Beam(self.moves.n_moves, self.beam_width)
pred.initialize(_init_state, tokens.length, tokens.c) pred.initialize(_init_state, tokens.length, tokens.c)
pred.check_done(_check_final_state, NULL) pred.check_done(_check_final_state, NULL)
@ -124,13 +127,15 @@ cdef class BeamParser(Parser):
elif pred._states[i].loss == 0.0: elif pred._states[i].loss == 0.0:
pred._states[i].loss = 1.0 pred._states[i].loss = 1.0
violn.check_crf(pred, gold) violn.check_crf(pred, gold)
_check_train_integrity(pred, gold, gold_parse, self.moves) assert pred.size >= 1
assert gold.size >= 1
#_check_train_integrity(pred, gold, gold_parse, self.moves)
histories = zip(violn.p_probs, violn.p_hist) + zip(violn.g_probs, violn.g_hist) histories = zip(violn.p_probs, violn.p_hist) + zip(violn.g_probs, violn.g_hist)
min_grad = 0.001 ** (itn+1) min_grad = 0.001 ** (itn+1)
histories = [(grad, hist) for grad, hist in histories if abs(grad) >= min_grad] histories = [(grad, hist) for grad, hist in histories if abs(grad) >= min_grad]
random.shuffle(histories) random.shuffle(histories)
for grad, hist in histories: for grad, hist in histories:
assert not math.isnan(grad) and not math.isinf(grad) assert not math.isnan(grad) and not math.isinf(grad), hist
self.model._update_from_history(self.moves, tokens, hist, grad) self.model._update_from_history(self.moves, tokens, hist, grad)
_cleanup(pred) _cleanup(pred)
_cleanup(gold) _cleanup(gold)
@ -155,7 +160,7 @@ cdef class BeamParser(Parser):
else: else:
for i in range(beam.size): for i in range(beam.size):
stcls = <StateClass>beam.at(i) stcls = <StateClass>beam.at(i)
if not stcls.c.is_final(): if not stcls.is_final():
nr_feat = self.model.set_featuresC(features, stcls.c) nr_feat = self.model.set_featuresC(features, stcls.c)
self.moves.set_valid(beam.is_valid[i], stcls.c) self.moves.set_valid(beam.is_valid[i], stcls.c)
self.model.set_scoresC(beam.scores[i], features, nr_feat) self.model.set_scoresC(beam.scores[i], features, nr_feat)
@ -185,12 +190,12 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
cdef StateClass st = StateClass.init(<const TokenC*>tokens, length) cdef StateClass st = StateClass.init(<const TokenC*>tokens, length)
# Ensure sent_start is set to 0 throughout ## Ensure sent_start is set to 0 throughout
for i in range(st.c.length): #for i in range(st.c.length):
st.c._sent[i].sent_start = False # st.c._sent[i].sent_start = False
st.c._sent[i].l_edge = i # st.c._sent[i].l_edge = i
st.c._sent[i].r_edge = i # st.c._sent[i].r_edge = i
st.fast_forward() #st.fast_forward()
Py_INCREF(st) Py_INCREF(st)
return <void*>st return <void*>st
@ -219,7 +224,6 @@ def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, Transitio
continue continue
state = <StateClass>pred.at(i) state = <StateClass>pred.at(i)
if is_gold(state, gold_parse, moves.strings) == True: if is_gold(state, gold_parse, moves.strings) == True:
print("Truth")
for dep in gold_parse.orig_annot: for dep in gold_parse.orig_annot:
print(dep[1], dep[3], dep[4]) print(dep[1], dep[3], dep[4])
print("Cost", pred._states[i].loss) print("Cost", pred._states[i].loss)

View File

@ -84,7 +84,7 @@ def ParserFactory(transition_system):
cdef class Parser: cdef class Parser:
def __init__(self, StringStore strings, transition_system, model): def __init__(self, StringStore strings, transition_system, model, *args, **kwargs):
self.moves = transition_system self.moves = transition_system
self.model = model self.model = model
@ -98,17 +98,18 @@ cdef class Parser:
moves = transition_system(strings, cfg.labels) moves = transition_system(strings, cfg.labels)
if cfg.get('model') == 'neural': if cfg.get('model') == 'neural':
model = ParserNeuralNet(cfg.hidden_layers + [moves.n_moves], model = ParserNeuralNet(cfg.hyper_params['hidden_layers'] + [moves.n_moves],
update_step=cfg.update_step, eta=cfg.eta, rho=cfg.rho, update_step=cfg.hyper_params['update_step'],
noise=cfg.noise) eta=cfg.hyper_params['learn_rate'],
rho=cfg.hyper_params['L2'],
noise=cfg.hyper_params['noise'])
else: else:
model = ParserPerceptron(get_templates(cfg.feat_set), model = ParserPerceptron(get_templates(cfg.feat_set),
learn_rate=cfg.get('eta', 0.001), learn_rate=cfg.get('eta', 0.001),
l1_penalty=cfg.rho) l1_penalty=cfg.rho)
if path.exists(path.join(model_dir, 'model')): if path.exists(path.join(model_dir, 'model')):
model.load(path.join(model_dir, 'model')) model.load(path.join(model_dir, 'model'))
return cls(strings, moves, model) return cls(strings, moves, model, beam_width=cfg.get('beam_width', 1))
@classmethod @classmethod
def load(cls, pkg_or_str_or_file, vocab): def load(cls, pkg_or_str_or_file, vocab):
@ -206,16 +207,13 @@ cdef class Parser:
eg.c.nr_feat = self.model.set_featuresC(eg.c.features, stcls.c) eg.c.nr_feat = self.model.set_featuresC(eg.c.features, stcls.c)
self.model.dropoutC(eg.c.features, self.model.dropoutC(eg.c.features,
0.5, eg.c.nr_feat) 0.5, eg.c.nr_feat)
if eg.c.features[0].i == 1: if eg.c.features[0].key == 1:
eg.c.features[0].value = 1.0 eg.c.features[0].value = 1.0
#for i in range(eg.c.nr_feat):
# if eg.c.features[i].value != 0:
# self.model.apply_L1(eg.c.features[i].key)
self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat) self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat)
self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold) self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold)
action = self.moves.c[eg.guess] action = self.moves.c[eg.guess]
action.do(stcls.c, action.label) action.do(stcls.c, action.label)
loss += self.model.update(eg, loss='nll') loss += self.model.update(eg)
eg.reset() eg.reset()
return loss return loss