* Impove efficiency of dynamic oracle, making beam training faster

This commit is contained in:
Matthew Honnibal 2015-06-04 21:15:14 +02:00
parent 079dad28a7
commit 4433396005
2 changed files with 42 additions and 17 deletions

View File

@ -152,6 +152,35 @@ cdef class ArcEager(TransitionSystem):
for i in range(self.n_moves):
output[i] = is_valid[self.c[i].move]
cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1:
cdef Transition move
move.label = -1
cdef int[N_MOVES] move_costs
move_costs[SHIFT] = _shift_cost(&move, s, &gold.c)
move_costs[REDUCE] = _reduce_cost(&move, s, &gold.c)
move_costs[LEFT] = _left_cost(&move, s, &gold.c)
move_costs[RIGHT] = _right_cost(&move, s, &gold.c)
move_costs[BREAK] = _break_cost(&move, s, &gold.c)
move_costs[CONSTITUENT] = _constituent_cost(&move, s, &gold.c)
move_costs[ADJUST] = _adjust_cost(&move, s, &gold.c)
cdef int i, label
cdef int* labels = gold.c.labels
cdef int* heads = gold.c.heads
for i in range(self.n_moves):
move = self.c[i]
output[i] = move_costs[move.move]
if output[i] == 0:
label = -1
if move.move == RIGHT and heads[s.i] == s.stack[0]:
label = labels[s.i]
if move.move == LEFT and heads[s.stack[0]] == s.i:
label = labels[s.stack[0]]
elif move.move == LEFT and at_eol(s) and (_can_reduce(s) or _can_break(s)):
label = labels[s.stack[0]]
output[i] += move.label != label and label != -1
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = _can_shift(s)
@ -234,7 +263,7 @@ cdef int _right_cost(const Transition* self, const State* s, GoldParseC* gold) e
return 9000
cost = 0
if gold.heads[s.i] == s.stack[0]:
cost += self.label != gold.labels[s.i]
cost += self.label != -1 and self.label != gold.labels[s.i]
return cost
# This indicates missing head
if gold.labels[s.i] != -1:
@ -249,7 +278,7 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParseC* gold) ex
return 9000
cost = 0
if gold.heads[s.stack[0]] == s.i:
cost += self.label != gold.labels[s.stack[0]]
cost += self.label != -1 and self.label != gold.labels[s.stack[0]]
return cost
# If we're at EOL, then the left arc will add an arc to ROOT.
elif at_eol(s):
@ -259,7 +288,7 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParseC* gold) ex
if _can_reduce(s) or _can_break(s):
cost += gold.heads[s.stack[0]] != s.stack[0]
# Are we labelling correctly?
cost += self.label != gold.labels[s.stack[0]]
cost += self.label != -1 and self.label != gold.labels[s.stack[0]]
return cost
cost += head_in_buffer(s, s.stack[0], gold.heads)

View File

@ -153,10 +153,12 @@ cdef class Parser:
self._advance_beam(pred, gold_parse, False)
self._advance_beam(gold, gold_parse, True)
violn.check(pred, gold)
counts = {}
if pred.loss >= 1:
counts = {clas: {} for clas in range(self.model.n_classes)}
self._count_feats(counts, tokens, violn.g_hist, 1)
self._count_feats(counts, tokens, violn.p_hist, -1)
else:
counts = {}
self.model._model.update(counts)
return pred.loss
@ -171,20 +173,14 @@ cdef class Parser:
fill_context(context, state)
self.model.set_scores(beam.scores[i], context)
self.moves.set_valid(beam.is_valid[i], state)
if follow_gold:
if gold is not None:
for i in range(beam.size):
state = <State*>beam.at(i)
for j in range(self.moves.n_moves):
move = &self.moves.c[j]
beam.costs[i][j] = move.get_cost(move, state, &gold.c)
beam.is_valid[i][j] = beam.costs[i][j] == 0
elif gold is not None:
for i in range(beam.size):
state = <State*>beam.at(i)
for j in range(self.moves.n_moves):
move = &self.moves.c[j]
beam.costs[i][j] = move.get_cost(move, state, &gold.c)
self.moves.set_costs(beam.costs[i], state, gold)
if follow_gold:
for j in range(self.moves.n_moves):
beam.is_valid[i][j] = beam.costs[i][j] == 0
beam.advance(_transition_state, <void*>self.moves.c)
state = <State*>beam.at(0)
if state.sent[state.i].sent_end:
@ -204,7 +200,7 @@ cdef class Parser:
break
fill_context(context, state)
feats = self.model._extractor.get_feats(context, &n_feats)
count_feats(counts.setdefault(clas, {}), feats, n_feats, inc)
count_feats(counts[clas], feats, n_feats, inc)
self.moves.c[clas].do(&self.moves.c[clas], state)