mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-29 17:33:10 +03:00
* Impove efficiency of dynamic oracle, making beam training faster
This commit is contained in:
parent
079dad28a7
commit
4433396005
|
@ -152,6 +152,35 @@ cdef class ArcEager(TransitionSystem):
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
output[i] = is_valid[self.c[i].move]
|
output[i] = is_valid[self.c[i].move]
|
||||||
|
|
||||||
|
cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1:
|
||||||
|
cdef Transition move
|
||||||
|
move.label = -1
|
||||||
|
cdef int[N_MOVES] move_costs
|
||||||
|
move_costs[SHIFT] = _shift_cost(&move, s, &gold.c)
|
||||||
|
move_costs[REDUCE] = _reduce_cost(&move, s, &gold.c)
|
||||||
|
move_costs[LEFT] = _left_cost(&move, s, &gold.c)
|
||||||
|
move_costs[RIGHT] = _right_cost(&move, s, &gold.c)
|
||||||
|
move_costs[BREAK] = _break_cost(&move, s, &gold.c)
|
||||||
|
move_costs[CONSTITUENT] = _constituent_cost(&move, s, &gold.c)
|
||||||
|
move_costs[ADJUST] = _adjust_cost(&move, s, &gold.c)
|
||||||
|
|
||||||
|
cdef int i, label
|
||||||
|
cdef int* labels = gold.c.labels
|
||||||
|
cdef int* heads = gold.c.heads
|
||||||
|
for i in range(self.n_moves):
|
||||||
|
move = self.c[i]
|
||||||
|
output[i] = move_costs[move.move]
|
||||||
|
if output[i] == 0:
|
||||||
|
label = -1
|
||||||
|
if move.move == RIGHT and heads[s.i] == s.stack[0]:
|
||||||
|
label = labels[s.i]
|
||||||
|
if move.move == LEFT and heads[s.stack[0]] == s.i:
|
||||||
|
label = labels[s.stack[0]]
|
||||||
|
elif move.move == LEFT and at_eol(s) and (_can_reduce(s) or _can_break(s)):
|
||||||
|
label = labels[s.stack[0]]
|
||||||
|
output[i] += move.label != label and label != -1
|
||||||
|
|
||||||
|
|
||||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||||
cdef bint[N_MOVES] is_valid
|
cdef bint[N_MOVES] is_valid
|
||||||
is_valid[SHIFT] = _can_shift(s)
|
is_valid[SHIFT] = _can_shift(s)
|
||||||
|
@ -234,7 +263,7 @@ cdef int _right_cost(const Transition* self, const State* s, GoldParseC* gold) e
|
||||||
return 9000
|
return 9000
|
||||||
cost = 0
|
cost = 0
|
||||||
if gold.heads[s.i] == s.stack[0]:
|
if gold.heads[s.i] == s.stack[0]:
|
||||||
cost += self.label != gold.labels[s.i]
|
cost += self.label != -1 and self.label != gold.labels[s.i]
|
||||||
return cost
|
return cost
|
||||||
# This indicates missing head
|
# This indicates missing head
|
||||||
if gold.labels[s.i] != -1:
|
if gold.labels[s.i] != -1:
|
||||||
|
@ -249,7 +278,7 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParseC* gold) ex
|
||||||
return 9000
|
return 9000
|
||||||
cost = 0
|
cost = 0
|
||||||
if gold.heads[s.stack[0]] == s.i:
|
if gold.heads[s.stack[0]] == s.i:
|
||||||
cost += self.label != gold.labels[s.stack[0]]
|
cost += self.label != -1 and self.label != gold.labels[s.stack[0]]
|
||||||
return cost
|
return cost
|
||||||
# If we're at EOL, then the left arc will add an arc to ROOT.
|
# If we're at EOL, then the left arc will add an arc to ROOT.
|
||||||
elif at_eol(s):
|
elif at_eol(s):
|
||||||
|
@ -259,7 +288,7 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParseC* gold) ex
|
||||||
if _can_reduce(s) or _can_break(s):
|
if _can_reduce(s) or _can_break(s):
|
||||||
cost += gold.heads[s.stack[0]] != s.stack[0]
|
cost += gold.heads[s.stack[0]] != s.stack[0]
|
||||||
# Are we labelling correctly?
|
# Are we labelling correctly?
|
||||||
cost += self.label != gold.labels[s.stack[0]]
|
cost += self.label != -1 and self.label != gold.labels[s.stack[0]]
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
||||||
|
|
|
@ -153,10 +153,12 @@ cdef class Parser:
|
||||||
self._advance_beam(pred, gold_parse, False)
|
self._advance_beam(pred, gold_parse, False)
|
||||||
self._advance_beam(gold, gold_parse, True)
|
self._advance_beam(gold, gold_parse, True)
|
||||||
violn.check(pred, gold)
|
violn.check(pred, gold)
|
||||||
counts = {}
|
|
||||||
if pred.loss >= 1:
|
if pred.loss >= 1:
|
||||||
|
counts = {clas: {} for clas in range(self.model.n_classes)}
|
||||||
self._count_feats(counts, tokens, violn.g_hist, 1)
|
self._count_feats(counts, tokens, violn.g_hist, 1)
|
||||||
self._count_feats(counts, tokens, violn.p_hist, -1)
|
self._count_feats(counts, tokens, violn.p_hist, -1)
|
||||||
|
else:
|
||||||
|
counts = {}
|
||||||
self.model._model.update(counts)
|
self.model._model.update(counts)
|
||||||
return pred.loss
|
return pred.loss
|
||||||
|
|
||||||
|
@ -172,19 +174,13 @@ cdef class Parser:
|
||||||
self.model.set_scores(beam.scores[i], context)
|
self.model.set_scores(beam.scores[i], context)
|
||||||
self.moves.set_valid(beam.is_valid[i], state)
|
self.moves.set_valid(beam.is_valid[i], state)
|
||||||
|
|
||||||
|
if gold is not None:
|
||||||
|
for i in range(beam.size):
|
||||||
|
state = <State*>beam.at(i)
|
||||||
|
self.moves.set_costs(beam.costs[i], state, gold)
|
||||||
if follow_gold:
|
if follow_gold:
|
||||||
for i in range(beam.size):
|
|
||||||
state = <State*>beam.at(i)
|
|
||||||
for j in range(self.moves.n_moves):
|
for j in range(self.moves.n_moves):
|
||||||
move = &self.moves.c[j]
|
|
||||||
beam.costs[i][j] = move.get_cost(move, state, &gold.c)
|
|
||||||
beam.is_valid[i][j] = beam.costs[i][j] == 0
|
beam.is_valid[i][j] = beam.costs[i][j] == 0
|
||||||
elif gold is not None:
|
|
||||||
for i in range(beam.size):
|
|
||||||
state = <State*>beam.at(i)
|
|
||||||
for j in range(self.moves.n_moves):
|
|
||||||
move = &self.moves.c[j]
|
|
||||||
beam.costs[i][j] = move.get_cost(move, state, &gold.c)
|
|
||||||
beam.advance(_transition_state, <void*>self.moves.c)
|
beam.advance(_transition_state, <void*>self.moves.c)
|
||||||
state = <State*>beam.at(0)
|
state = <State*>beam.at(0)
|
||||||
if state.sent[state.i].sent_end:
|
if state.sent[state.i].sent_end:
|
||||||
|
@ -204,7 +200,7 @@ cdef class Parser:
|
||||||
break
|
break
|
||||||
fill_context(context, state)
|
fill_context(context, state)
|
||||||
feats = self.model._extractor.get_feats(context, &n_feats)
|
feats = self.model._extractor.get_feats(context, &n_feats)
|
||||||
count_feats(counts.setdefault(clas, {}), feats, n_feats, inc)
|
count_feats(counts[clas], feats, n_feats, inc)
|
||||||
self.moves.c[clas].do(&self.moves.c[clas], state)
|
self.moves.c[clas].do(&self.moves.c[clas], state)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user