mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Restore use of Example object in parser.train
This commit is contained in:
parent
14f3da2a2e
commit
1a1b2f9174
|
@ -29,6 +29,7 @@ from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t, idx_t
|
|||
from thinc.linear.avgtron cimport AveragedPerceptron
|
||||
from thinc.linalg cimport VecVec
|
||||
from thinc.structs cimport NeuralNetC, SparseArrayC, ExampleC
|
||||
from thinc.neural.nn cimport NeuralNet
|
||||
from thinc.extra.eg cimport Example
|
||||
|
||||
from preshed.maps cimport MapStruct
|
||||
|
@ -100,7 +101,9 @@ cdef class Parser:
|
|||
update_step=cfg.update_step, eta=cfg.eta, rho=cfg.rho,
|
||||
noise=cfg.noise)
|
||||
else:
|
||||
model = ParserPerceptron(get_templates(cfg.feat_set))
|
||||
model = ParserPerceptron(get_templates(cfg.feat_set),
|
||||
learn_rate=cfg.get('eta', 0.001),
|
||||
l1_penalty=cfg.rho)
|
||||
|
||||
if path.exists(path.join(model_dir, 'model')):
|
||||
model.load(path.join(model_dir, 'model'))
|
||||
|
@ -165,14 +168,13 @@ cdef class Parser:
|
|||
yield doc
|
||||
|
||||
cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) with gil:
|
||||
cdef Example py_eg = Example(nr_class=nr_class, nr_atom=CONTEXT_SIZE, nr_feat=nr_feat,
|
||||
widths=self.model.widths)
|
||||
cdef Example py_eg = Example(nr_class=nr_class, nr_feat=nr_feat)
|
||||
cdef ExampleC* eg = py_eg.c
|
||||
state = new StateC(tokens, length)
|
||||
self.moves.initialize_state(state)
|
||||
cdef int i
|
||||
while not state.is_final():
|
||||
self.model.set_featuresC(eg, state)
|
||||
eg.nr_feat = self.model.set_featuresC(eg.features, state)
|
||||
self.moves.set_valid(eg.is_valid, state)
|
||||
self.model.set_scoresC(eg.scores, eg.features, eg.nr_feat)
|
||||
|
||||
|
@ -190,25 +192,23 @@ cdef class Parser:
|
|||
del state
|
||||
return 0
|
||||
|
||||
def train(self, Doc tokens, GoldParse gold):
|
||||
def train(self, Doc tokens, GoldParse gold, itn=0):
|
||||
self.moves.preprocess_gold(gold)
|
||||
cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
|
||||
self.moves.initialize_state(stcls.c)
|
||||
cdef Pool mem = Pool()
|
||||
cdef Example eg = Example(
|
||||
nr_class=self.moves.n_moves,
|
||||
widths=self.model.widths,
|
||||
nr_atom=CONTEXT_SIZE,
|
||||
nr_feat=self.model.nr_feat)
|
||||
loss = 0
|
||||
cdef Transition action
|
||||
while not stcls.is_final():
|
||||
self.model.set_featuresC(eg.c, stcls.c)
|
||||
eg.c.nr_feat = self.model.set_featuresC(eg.c.features, stcls.c)
|
||||
self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat)
|
||||
self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold)
|
||||
guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class)
|
||||
assert guess >= 0
|
||||
action = self.moves.c[guess]
|
||||
for i in range(self.moves.n_moves):
|
||||
if eg.c.costs[i] < 0:
|
||||
eg.c.costs[i] = 0
|
||||
action = self.moves.c[eg.guess]
|
||||
action.do(stcls.c, action.label)
|
||||
|
||||
loss += self.model.update(eg)
|
||||
|
@ -273,7 +273,7 @@ cdef class StepwiseState:
|
|||
|
||||
def predict(self):
|
||||
self.eg.reset()
|
||||
self.parser.model.set_featuresC(self.eg.c, self.stcls.c)
|
||||
self.eg.c.nr_feat = self.parser.model.set_featuresC(self.eg.c.features, self.stcls.c)
|
||||
self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls.c)
|
||||
self.parser.model.set_scoresC(self.eg.c.scores,
|
||||
self.eg.c.features, self.eg.c.nr_feat)
|
||||
|
|
Loading…
Reference in New Issue
Block a user