diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index dc7a96777..ff480de40 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -152,6 +152,35 @@ cdef class ArcEager(TransitionSystem): for i in range(self.n_moves): output[i] = is_valid[self.c[i].move] + cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1: + cdef Transition move + move.label = -1 + cdef int[N_MOVES] move_costs + move_costs[SHIFT] = _shift_cost(&move, s, &gold.c) + move_costs[REDUCE] = _reduce_cost(&move, s, &gold.c) + move_costs[LEFT] = _left_cost(&move, s, &gold.c) + move_costs[RIGHT] = _right_cost(&move, s, &gold.c) + move_costs[BREAK] = _break_cost(&move, s, &gold.c) + move_costs[CONSTITUENT] = _constituent_cost(&move, s, &gold.c) + move_costs[ADJUST] = _adjust_cost(&move, s, &gold.c) + + cdef int i, label + cdef int* labels = gold.c.labels + cdef int* heads = gold.c.heads + for i in range(self.n_moves): + move = self.c[i] + output[i] = move_costs[move.move] + if output[i] == 0: + label = -1 + if move.move == RIGHT and heads[s.i] == s.stack[0]: + label = labels[s.i] + if move.move == LEFT and heads[s.stack[0]] == s.i: + label = labels[s.stack[0]] + elif move.move == LEFT and at_eol(s) and (_can_reduce(s) or _can_break(s)): + label = labels[s.stack[0]] + output[i] += move.label != label and label != -1 + + cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: cdef bint[N_MOVES] is_valid is_valid[SHIFT] = _can_shift(s) @@ -234,7 +263,7 @@ cdef int _right_cost(const Transition* self, const State* s, GoldParseC* gold) e return 9000 cost = 0 if gold.heads[s.i] == s.stack[0]: - cost += self.label != gold.labels[s.i] + cost += self.label != -1 and self.label != gold.labels[s.i] return cost # This indicates missing head if gold.labels[s.i] != -1: @@ -249,7 +278,7 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParseC* gold) ex return 9000 cost = 0 if gold.heads[s.stack[0]] == s.i: - cost += self.label != gold.labels[s.stack[0]] + cost += self.label != -1 and self.label != gold.labels[s.stack[0]] return cost # If we're at EOL, then the left arc will add an arc to ROOT. elif at_eol(s): @@ -259,7 +288,7 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParseC* gold) ex if _can_reduce(s) or _can_break(s): cost += gold.heads[s.stack[0]] != s.stack[0] # Are we labelling correctly? - cost += self.label != gold.labels[s.stack[0]] + cost += self.label != -1 and self.label != gold.labels[s.stack[0]] return cost cost += head_in_buffer(s, s.stack[0], gold.heads) diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 6114c8a0a..3772ea0f1 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -153,10 +153,12 @@ cdef class Parser: self._advance_beam(pred, gold_parse, False) self._advance_beam(gold, gold_parse, True) violn.check(pred, gold) - counts = {} if pred.loss >= 1: + counts = {clas: {} for clas in range(self.model.n_classes)} self._count_feats(counts, tokens, violn.g_hist, 1) self._count_feats(counts, tokens, violn.p_hist, -1) + else: + counts = {} self.model._model.update(counts) return pred.loss @@ -171,20 +173,14 @@ cdef class Parser: fill_context(context, state) self.model.set_scores(beam.scores[i], context) self.moves.set_valid(beam.is_valid[i], state) - - if follow_gold: + + if gold is not None: for i in range(beam.size): state = beam.at(i) - for j in range(self.moves.n_moves): - move = &self.moves.c[j] - beam.costs[i][j] = move.get_cost(move, state, &gold.c) - beam.is_valid[i][j] = beam.costs[i][j] == 0 - elif gold is not None: - for i in range(beam.size): - state = beam.at(i) - for j in range(self.moves.n_moves): - move = &self.moves.c[j] - beam.costs[i][j] = move.get_cost(move, state, &gold.c) + self.moves.set_costs(beam.costs[i], state, gold) + if follow_gold: + for j in range(self.moves.n_moves): + beam.is_valid[i][j] = beam.costs[i][j] == 0 beam.advance(_transition_state, self.moves.c) state = beam.at(0) if state.sent[state.i].sent_end: @@ -204,7 +200,7 @@ cdef class Parser: break fill_context(context, state) feats = self.model._extractor.get_feats(context, &n_feats) - count_feats(counts.setdefault(clas, {}), feats, n_feats, inc) + count_feats(counts[clas], feats, n_feats, inc) self.moves.c[clas].do(&self.moves.c[clas], state)