mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
* Impove efficiency of dynamic oracle, making beam training faster
This commit is contained in:
parent
079dad28a7
commit
4433396005
|
@ -152,6 +152,35 @@ cdef class ArcEager(TransitionSystem):
|
|||
for i in range(self.n_moves):
|
||||
output[i] = is_valid[self.c[i].move]
|
||||
|
||||
cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1:
|
||||
cdef Transition move
|
||||
move.label = -1
|
||||
cdef int[N_MOVES] move_costs
|
||||
move_costs[SHIFT] = _shift_cost(&move, s, &gold.c)
|
||||
move_costs[REDUCE] = _reduce_cost(&move, s, &gold.c)
|
||||
move_costs[LEFT] = _left_cost(&move, s, &gold.c)
|
||||
move_costs[RIGHT] = _right_cost(&move, s, &gold.c)
|
||||
move_costs[BREAK] = _break_cost(&move, s, &gold.c)
|
||||
move_costs[CONSTITUENT] = _constituent_cost(&move, s, &gold.c)
|
||||
move_costs[ADJUST] = _adjust_cost(&move, s, &gold.c)
|
||||
|
||||
cdef int i, label
|
||||
cdef int* labels = gold.c.labels
|
||||
cdef int* heads = gold.c.heads
|
||||
for i in range(self.n_moves):
|
||||
move = self.c[i]
|
||||
output[i] = move_costs[move.move]
|
||||
if output[i] == 0:
|
||||
label = -1
|
||||
if move.move == RIGHT and heads[s.i] == s.stack[0]:
|
||||
label = labels[s.i]
|
||||
if move.move == LEFT and heads[s.stack[0]] == s.i:
|
||||
label = labels[s.stack[0]]
|
||||
elif move.move == LEFT and at_eol(s) and (_can_reduce(s) or _can_break(s)):
|
||||
label = labels[s.stack[0]]
|
||||
output[i] += move.label != label and label != -1
|
||||
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||
cdef bint[N_MOVES] is_valid
|
||||
is_valid[SHIFT] = _can_shift(s)
|
||||
|
@ -234,7 +263,7 @@ cdef int _right_cost(const Transition* self, const State* s, GoldParseC* gold) e
|
|||
return 9000
|
||||
cost = 0
|
||||
if gold.heads[s.i] == s.stack[0]:
|
||||
cost += self.label != gold.labels[s.i]
|
||||
cost += self.label != -1 and self.label != gold.labels[s.i]
|
||||
return cost
|
||||
# This indicates missing head
|
||||
if gold.labels[s.i] != -1:
|
||||
|
@ -249,7 +278,7 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParseC* gold) ex
|
|||
return 9000
|
||||
cost = 0
|
||||
if gold.heads[s.stack[0]] == s.i:
|
||||
cost += self.label != gold.labels[s.stack[0]]
|
||||
cost += self.label != -1 and self.label != gold.labels[s.stack[0]]
|
||||
return cost
|
||||
# If we're at EOL, then the left arc will add an arc to ROOT.
|
||||
elif at_eol(s):
|
||||
|
@ -259,7 +288,7 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParseC* gold) ex
|
|||
if _can_reduce(s) or _can_break(s):
|
||||
cost += gold.heads[s.stack[0]] != s.stack[0]
|
||||
# Are we labelling correctly?
|
||||
cost += self.label != gold.labels[s.stack[0]]
|
||||
cost += self.label != -1 and self.label != gold.labels[s.stack[0]]
|
||||
return cost
|
||||
|
||||
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
||||
|
|
|
@ -153,10 +153,12 @@ cdef class Parser:
|
|||
self._advance_beam(pred, gold_parse, False)
|
||||
self._advance_beam(gold, gold_parse, True)
|
||||
violn.check(pred, gold)
|
||||
counts = {}
|
||||
if pred.loss >= 1:
|
||||
counts = {clas: {} for clas in range(self.model.n_classes)}
|
||||
self._count_feats(counts, tokens, violn.g_hist, 1)
|
||||
self._count_feats(counts, tokens, violn.p_hist, -1)
|
||||
else:
|
||||
counts = {}
|
||||
self.model._model.update(counts)
|
||||
return pred.loss
|
||||
|
||||
|
@ -171,20 +173,14 @@ cdef class Parser:
|
|||
fill_context(context, state)
|
||||
self.model.set_scores(beam.scores[i], context)
|
||||
self.moves.set_valid(beam.is_valid[i], state)
|
||||
|
||||
if follow_gold:
|
||||
|
||||
if gold is not None:
|
||||
for i in range(beam.size):
|
||||
state = <State*>beam.at(i)
|
||||
for j in range(self.moves.n_moves):
|
||||
move = &self.moves.c[j]
|
||||
beam.costs[i][j] = move.get_cost(move, state, &gold.c)
|
||||
beam.is_valid[i][j] = beam.costs[i][j] == 0
|
||||
elif gold is not None:
|
||||
for i in range(beam.size):
|
||||
state = <State*>beam.at(i)
|
||||
for j in range(self.moves.n_moves):
|
||||
move = &self.moves.c[j]
|
||||
beam.costs[i][j] = move.get_cost(move, state, &gold.c)
|
||||
self.moves.set_costs(beam.costs[i], state, gold)
|
||||
if follow_gold:
|
||||
for j in range(self.moves.n_moves):
|
||||
beam.is_valid[i][j] = beam.costs[i][j] == 0
|
||||
beam.advance(_transition_state, <void*>self.moves.c)
|
||||
state = <State*>beam.at(0)
|
||||
if state.sent[state.i].sent_end:
|
||||
|
@ -204,7 +200,7 @@ cdef class Parser:
|
|||
break
|
||||
fill_context(context, state)
|
||||
feats = self.model._extractor.get_feats(context, &n_feats)
|
||||
count_feats(counts.setdefault(clas, {}), feats, n_feats, inc)
|
||||
count_feats(counts[clas], feats, n_feats, inc)
|
||||
self.moves.c[clas].do(&self.moves.c[clas], state)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user