* Rejig parser interface to use new thinc.api.Example class, in prep of theano model. Comment out beam search

This commit is contained in:
Matthew Honnibal 2015-06-26 06:25:36 +02:00
parent 886100e1a2
commit 6896455884
7 changed files with 132 additions and 178 deletions

View File

@ -10,6 +10,7 @@ import cython
import numpy.random
from thinc.features cimport Feature, count_feats
from thinc.api cimport Example
cdef int arg_max(const weight_t* scores, const int n_classes) nogil:
@ -23,6 +24,30 @@ cdef int arg_max(const weight_t* scores, const int n_classes) nogil:
return best
cdef int arg_max_if_true(const weight_t* scores, const bint* is_valid,
const int n_classes) nogil:
cdef int i
cdef int best = 0
cdef weight_t mode = -900000
for i in range(n_classes):
if is_valid[i] and scores[i] > mode:
mode = scores[i]
best = i
return best
cdef int arg_max_if_zero(const weight_t* scores, const int* costs,
const int n_classes) nogil:
cdef int i
cdef int best = 0
cdef weight_t mode = -900000
for i in range(n_classes):
if costs[i] == 0 and scores[i] > mode:
mode = scores[i]
best = i
return best
cdef class Model:
def __init__(self, n_classes, templates, model_loc=None):
if model_loc is not None and path.isdir(model_loc):
@ -34,6 +59,17 @@ cdef class Model:
if self.model_loc and path.exists(self.model_loc):
self._model.load(self.model_loc, freq_thresh=0)
def predict(self, Example eg):
self.set_scores(eg.scores, eg.atoms)
eg.guess = arg_max_if_true(eg.scores, eg.is_valid, self.n_classes)
def train(self, Example eg):
self.set_scores(eg.scores, eg.atoms)
eg.guess = arg_max_if_true(eg.scores, eg.is_valid, self.n_classes)
eg.best = arg_max_if_zero(eg.scores, eg.costs, self.n_classes)
eg.cost = eg.costs[eg.guess]
self.update(eg.atoms, eg.guess, eg.best, eg.cost)
cdef const weight_t* score(self, atom_t* context) except NULL:
cdef int n_feats
feats = self._extractor.get_feats(context, &n_feats)

View File

@ -398,7 +398,8 @@ cdef class ArcEager(TransitionSystem):
n_valid += output[i]
assert n_valid >= 1
cdef int set_costs(self, int* output, StateClass stcls, GoldParse gold) except -1:
cdef int set_costs(self, bint* is_valid, int* costs,
StateClass stcls, GoldParse gold) except -1:
cdef int i, move, label
cdef label_cost_func_t[N_MOVES] label_cost_funcs
cdef move_cost_func_t[N_MOVES] move_cost_funcs
@ -423,30 +424,14 @@ cdef class ArcEager(TransitionSystem):
n_gold = 0
for i in range(self.n_moves):
if self.c[i].is_valid(stcls, self.c[i].label):
is_valid[i] = True
move = self.c[i].move
label = self.c[i].label
if move_costs[move] == -1:
move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
output[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
n_gold += output[i] == 0
costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
n_gold += costs[i] == 0
else:
output[i] = 9000
is_valid[i] = False
costs[i] = 9000
assert n_gold >= 1
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:
cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = Shift.is_valid(stcls, -1)
is_valid[REDUCE] = Reduce.is_valid(stcls, -1)
is_valid[LEFT] = LeftArc.is_valid(stcls, -1)
is_valid[RIGHT] = RightArc.is_valid(stcls, -1)
is_valid[BREAK] = Break.is_valid(stcls, -1)
cdef Transition best
cdef weight_t score = MIN_SCORE
cdef int i
for i in range(self.n_moves):
if scores[i] > score and is_valid[self.c[i].move]:
best = self.c[i]
score = scores[i]
assert best.clas < self.n_moves
assert score > MIN_SCORE, (stcls.stack_depth(), stcls.buffer_length(), stcls.is_final(), stcls._b_i, stcls.length)
return best

View File

@ -128,27 +128,6 @@ cdef class BiluoPushDown(TransitionSystem):
raise Exception(move)
return t
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:
cdef int best = -1
cdef weight_t score = -90000
cdef const Transition* m
cdef int i
for i in range(self.n_moves):
m = &self.c[i]
if m.is_valid(stcls, m.label) and scores[i] > score:
best = i
score = scores[i]
assert best >= 0
cdef Transition t = self.c[best]
t.score = score
return t
cdef int set_valid(self, bint* output, StateClass stcls) except -1:
cdef int i
for i in range(self.n_moves):
m = &self.c[i]
output[i] = m.is_valid(stcls, m.label)
cdef class Missing:
@staticmethod

View File

@ -11,6 +11,3 @@ cdef class Parser:
cdef readonly object cfg
cdef readonly Model model
cdef readonly TransitionSystem moves
cdef int _greedy_parse(self, Tokens tokens) except -1
cdef int _beam_parse(self, Tokens tokens) except -1

View File

@ -20,17 +20,10 @@ from cymem.cymem cimport Pool, Address
from murmurhash.mrmr cimport hash64
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
from util import Config
from thinc.features cimport Extractor
from thinc.features cimport Feature
from thinc.features cimport count_feats
from thinc.api cimport Example
from thinc.learner cimport LinearModel
from thinc.search cimport Beam
from thinc.search cimport MaxViolation
from ..tokens cimport Tokens, TokenC
from ..strings cimport StringStore
@ -73,35 +66,86 @@ cdef class Parser:
self.model = Model(self.moves.n_moves, templates, model_dir)
def __call__(self, Tokens tokens):
if self.cfg.get('beam_width', 1) < 1:
self._greedy_parse(tokens)
else:
self._beam_parse(tokens)
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
self.moves.initialize_state(stcls)
cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE)
while not stcls.is_final():
eg.wipe()
fill_context(eg.atoms, stcls)
self.moves.set_valid(eg.is_valid, stcls)
self.model.predict(eg)
self.moves.c[eg.guess].do(stcls, self.moves.c[eg.guess].label)
self.moves.finalize_state(stcls)
tokens.set_parse(stcls._sent)
def train(self, Tokens tokens, GoldParse gold):
self.moves.preprocess_gold(gold)
if self.cfg.beam_width < 1:
return self._greedy_train(tokens, gold)
else:
return self._beam_train(tokens, gold)
cdef int _greedy_parse(self, Tokens tokens) except -1:
cdef atom_t[CONTEXT_SIZE] context
cdef int n_feats
cdef Pool mem = Pool()
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
self.moves.initialize_state(stcls)
cdef Transition guess
words = [w.orth_ for w in tokens]
cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE)
cdef int cost = 0
while not stcls.is_final():
fill_context(context, stcls)
scores = self.model.score(context)
guess = self.moves.best_valid(scores, stcls)
#print self.moves.move_name(guess.move, guess.label), stcls.print_state(words)
guess.do(stcls, guess.label)
assert stcls._s_i >= 0
self.moves.finalize_state(stcls)
tokens.set_parse(stcls._sent)
eg.wipe()
fill_context(eg.atoms, stcls)
self.moves.set_costs(eg.is_valid, eg.costs, stcls, gold)
self.model.train(eg)
self.moves.c[eg.guess].do(stcls, self.moves.c[eg.guess].label)
cost += eg.cost
return cost
# These are passed as callbacks to thinc.search.Beam
"""
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
dest = <StateClass>_dest
src = <StateClass>_src
moves = <const Transition*>_moves
dest.clone(src)
moves[clas].do(dest, moves[clas].label)
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
cdef StateClass st = StateClass.init(<const TokenC*>tokens, length)
st.fast_forward()
Py_INCREF(st)
return <void*>st
cdef int _check_final_state(void* _state, void* extra_args) except -1:
return (<StateClass>_state).is_final()
def _cleanup(Beam beam):
for i in range(beam.width):
Py_XDECREF(<PyObject*>beam._states[i].content)
Py_XDECREF(<PyObject*>beam._parents[i].content)
cdef hash_t _hash_state(void* _state, void* _) except 0:
return <hash_t>_state
#state = <const State*>_state
#cdef atom_t[10] rep
#rep[0] = state.stack[0] if state.stack_len >= 1 else 0
#rep[1] = state.stack[-1] if state.stack_len >= 2 else 0
#rep[2] = state.stack[-2] if state.stack_len >= 3 else 0
#rep[3] = state.i
#rep[4] = state.sent[state.stack[0]].l_kids if state.stack_len >= 1 else 0
#rep[5] = state.sent[state.stack[0]].r_kids if state.stack_len >= 1 else 0
#rep[6] = state.sent[state.stack[0]].dep if state.stack_len >= 1 else 0
#rep[7] = state.sent[state.stack[-1]].dep if state.stack_len >= 2 else 0
#if get_left(state, get_n0(state), 1) != NULL:
# rep[8] = get_left(state, get_n0(state), 1).dep
#else:
# rep[8] = 0
#rep[9] = state.sent[state.i].l_kids
#return hash64(rep, sizeof(atom_t) * 10, 0)
cdef int _beam_parse(self, Tokens tokens) except -1:
cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width)
@ -115,30 +159,6 @@ cdef class Parser:
tokens.set_parse(state._sent)
_cleanup(beam)
def _greedy_train(self, Tokens tokens, GoldParse gold):
cdef Pool mem = Pool()
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
self.moves.initialize_state(stcls)
cdef int cost
cdef const Feature* feats
cdef const weight_t* scores
cdef Transition guess
cdef Transition best
cdef atom_t[CONTEXT_SIZE] context
loss = 0
words = [w.orth_ for w in tokens]
history = []
while not stcls.is_final():
fill_context(context, stcls)
scores = self.model.score(context)
guess = self.moves.best_valid(scores, stcls)
best = self.moves.best_gold(scores, stcls, gold)
cost = guess.get_cost(stcls, &gold.c, guess.label)
self.model.update(context, guess.clas, best.clas, cost)
guess.do(stcls, guess.label)
loss += cost
return loss
def _beam_train(self, Tokens tokens, GoldParse gold_parse):
cdef Beam pred = Beam(self.moves.n_moves, self.cfg.beam_width)
@ -201,50 +221,4 @@ cdef class Parser:
count_feats(counts[clas], feats, n_feats, inc)
self.moves.c[clas].do(stcls, self.moves.c[clas].label)
# These are passed as callbacks to thinc.search.Beam
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
dest = <StateClass>_dest
src = <StateClass>_src
moves = <const Transition*>_moves
dest.clone(src)
moves[clas].do(dest, moves[clas].label)
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
cdef StateClass st = StateClass.init(<const TokenC*>tokens, length)
st.fast_forward()
Py_INCREF(st)
return <void*>st
cdef int _check_final_state(void* _state, void* extra_args) except -1:
return (<StateClass>_state).is_final()
def _cleanup(Beam beam):
for i in range(beam.width):
Py_XDECREF(<PyObject*>beam._states[i].content)
Py_XDECREF(<PyObject*>beam._parents[i].content)
cdef hash_t _hash_state(void* _state, void* _) except 0:
return <hash_t>_state
#state = <const State*>_state
#cdef atom_t[10] rep
#rep[0] = state.stack[0] if state.stack_len >= 1 else 0
#rep[1] = state.stack[-1] if state.stack_len >= 2 else 0
#rep[2] = state.stack[-2] if state.stack_len >= 3 else 0
#rep[3] = state.i
#rep[4] = state.sent[state.stack[0]].l_kids if state.stack_len >= 1 else 0
#rep[5] = state.sent[state.stack[0]].r_kids if state.stack_len >= 1 else 0
#rep[6] = state.sent[state.stack[0]].dep if state.stack_len >= 1 else 0
#rep[7] = state.sent[state.stack[-1]].dep if state.stack_len >= 2 else 0
#if get_left(state, get_n0(state), 1) != NULL:
# rep[8] = get_left(state, get_n0(state), 1).dep
#else:
# rep[8] = 0
#rep[9] = state.sent[state.i].l_kids
#return hash64(rep, sizeof(atom_t) * 10, 0)
"""

View File

@ -46,9 +46,5 @@ cdef class TransitionSystem:
cdef int set_valid(self, bint* output, StateClass state) except -1
cdef int set_costs(self, int* output, StateClass state, GoldParse gold) except -1
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *
cdef Transition best_gold(self, const weight_t* scores, StateClass state,
GoldParse gold) except *
cdef int set_costs(self, bint* is_valid, int* costs,
StateClass state, GoldParse gold) except -1

View File

@ -43,30 +43,17 @@ cdef class TransitionSystem:
cdef Transition init_transition(self, int clas, int move, int label) except *:
raise NotImplementedError
cdef Transition best_valid(self, const weight_t* scores, StateClass s) except *:
raise NotImplementedError
cdef int set_valid(self, bint* output, StateClass state) except -1:
raise NotImplementedError
cdef int set_costs(self, int* output, StateClass stcls, GoldParse gold) except -1:
cdef int set_valid(self, bint* is_valid, StateClass stcls) except -1:
cdef int i
for i in range(self.n_moves):
if self.c[i].is_valid(stcls, self.c[i].label):
output[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
is_valid[i] = self.c[i].is_valid(stcls, self.c[i].label)
cdef int set_costs(self, bint* is_valid, int* costs,
StateClass stcls, GoldParse gold) except -1:
cdef int i
self.set_valid(is_valid, stcls)
for i in range(self.n_moves):
if is_valid[i]:
costs[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
else:
output[i] = 9000
cdef Transition best_gold(self, const weight_t* scores, StateClass stcls,
GoldParse gold) except *:
cdef Transition best
cdef weight_t score = MIN_SCORE
cdef int i
for i in range(self.n_moves):
if self.c[i].is_valid(stcls, self.c[i].label):
cost = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
if scores[i] > score and cost == 0:
best = self.c[i]
score = scores[i]
assert score > MIN_SCORE
return best
costs[i] = 9000