Restore use of Example object in parser.train

This commit is contained in:
Matthew Honnibal 2016-08-29 14:26:43 +02:00
parent 14f3da2a2e
commit 1a1b2f9174

View File

@ -29,6 +29,7 @@ from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t, idx_t
from thinc.linear.avgtron cimport AveragedPerceptron from thinc.linear.avgtron cimport AveragedPerceptron
from thinc.linalg cimport VecVec from thinc.linalg cimport VecVec
from thinc.structs cimport NeuralNetC, SparseArrayC, ExampleC from thinc.structs cimport NeuralNetC, SparseArrayC, ExampleC
from thinc.neural.nn cimport NeuralNet
from thinc.extra.eg cimport Example from thinc.extra.eg cimport Example
from preshed.maps cimport MapStruct from preshed.maps cimport MapStruct
@ -100,7 +101,9 @@ cdef class Parser:
update_step=cfg.update_step, eta=cfg.eta, rho=cfg.rho, update_step=cfg.update_step, eta=cfg.eta, rho=cfg.rho,
noise=cfg.noise) noise=cfg.noise)
else: else:
model = ParserPerceptron(get_templates(cfg.feat_set)) model = ParserPerceptron(get_templates(cfg.feat_set),
learn_rate=cfg.get('eta', 0.001),
l1_penalty=cfg.rho)
if path.exists(path.join(model_dir, 'model')): if path.exists(path.join(model_dir, 'model')):
model.load(path.join(model_dir, 'model')) model.load(path.join(model_dir, 'model'))
@ -165,14 +168,13 @@ cdef class Parser:
yield doc yield doc
cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) with gil: cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) with gil:
cdef Example py_eg = Example(nr_class=nr_class, nr_atom=CONTEXT_SIZE, nr_feat=nr_feat, cdef Example py_eg = Example(nr_class=nr_class, nr_feat=nr_feat)
widths=self.model.widths)
cdef ExampleC* eg = py_eg.c cdef ExampleC* eg = py_eg.c
state = new StateC(tokens, length) state = new StateC(tokens, length)
self.moves.initialize_state(state) self.moves.initialize_state(state)
cdef int i cdef int i
while not state.is_final(): while not state.is_final():
self.model.set_featuresC(eg, state) eg.nr_feat = self.model.set_featuresC(eg.features, state)
self.moves.set_valid(eg.is_valid, state) self.moves.set_valid(eg.is_valid, state)
self.model.set_scoresC(eg.scores, eg.features, eg.nr_feat) self.model.set_scoresC(eg.scores, eg.features, eg.nr_feat)
@ -190,25 +192,23 @@ cdef class Parser:
del state del state
return 0 return 0
def train(self, Doc tokens, GoldParse gold): def train(self, Doc tokens, GoldParse gold, itn=0):
self.moves.preprocess_gold(gold) self.moves.preprocess_gold(gold)
cdef StateClass stcls = StateClass.init(tokens.c, tokens.length) cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
self.moves.initialize_state(stcls.c) self.moves.initialize_state(stcls.c)
cdef Pool mem = Pool()
cdef Example eg = Example( cdef Example eg = Example(
nr_class=self.moves.n_moves, nr_class=self.moves.n_moves,
widths=self.model.widths,
nr_atom=CONTEXT_SIZE,
nr_feat=self.model.nr_feat) nr_feat=self.model.nr_feat)
loss = 0 loss = 0
cdef Transition action cdef Transition action
while not stcls.is_final(): while not stcls.is_final():
self.model.set_featuresC(eg.c, stcls.c) eg.c.nr_feat = self.model.set_featuresC(eg.c.features, stcls.c)
self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat) self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat)
self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold) self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold)
guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class) for i in range(self.moves.n_moves):
assert guess >= 0 if eg.c.costs[i] < 0:
action = self.moves.c[guess] eg.c.costs[i] = 0
action = self.moves.c[eg.guess]
action.do(stcls.c, action.label) action.do(stcls.c, action.label)
loss += self.model.update(eg) loss += self.model.update(eg)
@ -273,7 +273,7 @@ cdef class StepwiseState:
def predict(self): def predict(self):
self.eg.reset() self.eg.reset()
self.parser.model.set_featuresC(self.eg.c, self.stcls.c) self.eg.c.nr_feat = self.parser.model.set_featuresC(self.eg.c.features, self.stcls.c)
self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls.c) self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls.c)
self.parser.model.set_scoresC(self.eg.c.scores, self.parser.model.set_scoresC(self.eg.c.scores,
self.eg.c.features, self.eg.c.nr_feat) self.eg.c.features, self.eg.c.nr_feat)