# Reviewer preamble. These '#' lines sit before the first "diff --git" header
# and are not part of any hunk; everything below them is the patch, unchanged.
#
# Files touched:
#   spacy/syntax/arc_eager.pyx
#   spacy/syntax/transition_system.pxd
#
# What the patch does
#   * Adds three module-level helpers for the arc-eager oracle:
#       push_cost(st, gold, target) = head_in_stack + children_in_stack
#       pop_cost(st, gold, target)  = children_in_buffer + head_in_buffer
#       arc_cost(gold, head, child, label) = 1 only when gold.heads[child] == head,
#         gold.labels[child] != -1 and gold.labels[child] != label; otherwise 0.
#   * Splits cost() of Shift, Reduce, LeftArc, RightArc and Break into
#       move_cost(s, gold) + label_cost(s, gold, label).
#     cost() itself still returns 9000 first when is_valid(s, label) fails.
#   * Adds move_cost/label_cost to Constituent and Adjust. Their move_cost
#     returns 9000 when is_valid(s, -1) fails and otherwise raises Exception;
#     their label_cost returns 0.
#   * Rewrites ArcEager.set_costs to fill two function-pointer tables
#     (move_cost_funcs, label_cost_funcs) indexed by move id. move_costs[] is
#     initialised to -1 and filled on first use per move, so each move_cost is
#     evaluated at most once per call; label_cost is evaluated for every
#     transition i and added to the cached move cost.
#   * Declares the ctypedefs move_cost_func_t and label_cost_func_t in
#     transition_system.pxd and cimports them in arc_eager.pyx.
#
# NOTE(review): points to confirm before merging
#   1. set_costs used to call X.cost(s, &gold.c, -1), which returned 9000 for
#      an invalid move. It now calls X.move_cost directly, and only
#      LeftArc/Constituent/Adjust.move_cost check is_valid. Shift, Reduce,
#      RightArc and Break.move_cost do not, so invalid moves of those kinds no
#      longer get 9000 from set_costs. Presumably callers mask them with
#      set_valid -- verify.
#   2. arc_cost returns 1 when label == -1 and the gold arc has a real label,
#      whereas LeftArc.label_cost (and the removed RightArc.cost line
#      "label != -1 and label != gold.labels[s.i]") return 0 for label == -1.
#      This only matters if RightArc.cost/label_cost is still called with -1.
#   3. The removed RightArc.cost added head_in_buffer(s, s.i, ...) when
#      gold.labels[s.i] != -1. The new RightArc.move_cost is
#      push_cost(...) - (gold.heads[s.i] == s.stack[0]) and has no buffer term;
#      the old logic survives only as commented-out lines placed after the
#      return in RightArc.label_cost (dead code). Confirm the drop is intended.
#   4. The removed LeftArc.cost charged a label mismatch in the at_eol branch;
#      the new LeftArc.label_cost charges only when
#      gold.heads[s.stack[0]] == s.i.
#   5. In set_costs, "cdef int* labels" and "cdef int* heads" remain as context
#      lines but the new loop body no longer reads them.
#   6. Shift.move_cost calls Break.is_valid(s, -1); the removed line passed
#      `label`.
#   7. move_costs uses -1 as the "not yet computed" sentinel, and the
#      move_cost functions are declared "except -1", so a legitimate cost can
#      never be -1 -- presumably costs are non-negative; TODO confirm for
#      RightArc.move_cost, which subtracts a boolean from push_cost.
#   8. In this copy the patch's line breaks and indentation are collapsed, so
#      the "@@ -a,b +c,d @@" counts cannot be checked against the text below.
diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 02a9e7d30..855535f4e 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -11,6 +11,7 @@ from ._state cimport count_left_kids from ..structs cimport TokenC from .transition_system cimport do_func_t, get_cost_func_t +from .transition_system cimport move_cost_func_t, label_cost_func_t from ..gold cimport GoldParse from ..gold cimport GoldParseC @@ -46,6 +47,35 @@ MOVE_NAMES[CONSTITUENT] = 'C' MOVE_NAMES[ADJUST] = 'A' +# Helper functions for the arc-eager oracle + +cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -1: + # When we push a word, we can't make arcs to or from the stack. So, we lose + # any of those arcs. + cdef int cost = 0 + cost += head_in_stack(st, target, gold.heads) + cost += children_in_stack(st, target, gold.heads) + return cost + + +cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1: + cdef int cost = 0 + cost += children_in_buffer(st, target, gold.heads) + cost += head_in_buffer(st, target, gold.heads) + return cost + + +cdef int arc_cost(const GoldParseC* gold, int head, int child, int label) except -1: + if gold.heads[child] != head: + return 0 + elif gold.labels[child] == -1: + return 0 + elif gold.labels[child] == label: + return 0 + else: + return 1 + + cdef class Shift: @staticmethod cdef bint is_valid(const State* s, int label) except -1: @@ -62,14 +92,20 @@ cdef class Shift: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: if not Shift.is_valid(s, label): return 9000 - cost = 0 - cost += head_in_stack(s, s.i, gold.heads) - cost += children_in_stack(s, s.i, gold.heads) + return Shift.move_cost(s, gold) + Shift.label_cost(s, gold, label) + + @staticmethod + cdef int move_cost(const State* s, const GoldParseC* gold) except -1: + cdef int cost = push_cost(s, gold, s.i) # If we can break, and there's no cost to doing so, we should - if Break.is_valid(s, label) 
and Break.cost(s, gold, -1) == 0: + if Break.is_valid(s, -1) and Break.cost(s, gold, -1) == 0: cost += 1 return cost + @staticmethod + cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1: + return 0 + cdef class Reduce: @staticmethod @@ -89,11 +125,19 @@ cdef class Reduce: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: if not Reduce.is_valid(s, label): return 9000 - cdef int cost = 0 - cost += children_in_buffer(s, s.stack[0], gold.heads) + return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label) + + @staticmethod + cdef int move_cost(const State* s, const GoldParseC* gold) except -1: if NON_MONOTONIC: - cost += head_in_buffer(s, s.stack[0], gold.heads) - return cost + return pop_cost(s, gold, s.stack[0]) + else: + return children_in_buffer(s, s.stack[0], gold.heads) + + @staticmethod + cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1: + return 0 + cdef class LeftArc: @@ -117,19 +161,21 @@ cdef class LeftArc: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: if not LeftArc.is_valid(s, label): return 9000 + return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label) + + @staticmethod + cdef int move_cost(const State* s, const GoldParseC* gold) except -1: + if not LeftArc.is_valid(s, -1): + return 9000 cdef int cost = 0 if gold.heads[s.stack[0]] == s.i: - cost += label != -1 and label != gold.labels[s.stack[0]] return cost - # If we're at EOL, then the left arc will add an arc to ROOT. elif at_eol(s): # Are we root? if gold.labels[s.stack[0]] != -1: # If we're at EOL, prefer to reduce or break over left-arc if Reduce.is_valid(s, -1) or Break.is_valid(s, -1): cost += gold.heads[s.stack[0]] != s.stack[0] - # Are we labelling correctly? 
- cost += label != -1 and label != gold.labels[s.stack[0]] return cost cost += head_in_buffer(s, s.stack[0], gold.heads) cost += children_in_buffer(s, s.stack[0], gold.heads) @@ -139,6 +185,14 @@ cdef class LeftArc: cost += gold.heads[s.stack[0]] == s.stack[0] return cost + @staticmethod + cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1: + if label == -1 or gold.labels[s.stack[0]] == -1: + return 0 + if gold.heads[s.stack[0]] == s.i and label != gold.labels[s.stack[0]]: + return 1 + return 0 + cdef class RightArc: @staticmethod @@ -154,17 +208,25 @@ cdef class RightArc: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: if not RightArc.is_valid(s, label): return 9000 - cdef int cost - cost = 0 - if gold.heads[s.i] == s.stack[0]: - cost += label != -1 and label != gold.labels[s.i] - return cost + return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label) + + @staticmethod + cdef int move_cost(const State* s, const GoldParseC* gold) except -1: + return push_cost(s, gold, s.i) - (gold.heads[s.i] == s.stack[0]) + + @staticmethod + cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1: + return arc_cost(gold, s.stack[0], s.i, label) + #cdef int cost = 0 + #if gold.heads[s.i] == s.stack[0]: + # cost += label != -1 and label != gold.labels[s.i] + # return cost # This indicates missing head - if gold.labels[s.i] != -1: - cost += head_in_buffer(s, s.i, gold.heads) - cost += children_in_stack(s, s.i, gold.heads) - cost += head_in_stack(s, s.i, gold.heads) - return cost + #if gold.labels[s.i] != -1: + # cost += head_in_buffer(s, s.i, gold.heads) + #cost += children_in_stack(s, s.i, gold.heads) + #cost += head_in_stack(s, s.i, gold.heads) + #return cost cdef class Break: @@ -207,6 +269,11 @@ cdef class Break: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: if not Break.is_valid(s, label): return 9000 + else: + return Break.move_cost(s, gold) + 
Break.label_cost(s, gold, label) + + @staticmethod + cdef int move_cost(const State* s, const GoldParseC* gold) except -1: # When we break, we Reduce all of the words on the stack. cdef int cost = 0 # Number of deps between S0...Sn and N0...Nn @@ -214,6 +281,10 @@ cdef class Break: cost += children_in_stack(s, i, gold.heads) cost += head_in_stack(s, i, gold.heads) return cost + + @staticmethod + cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1: + return 0 cdef class Constituent: @@ -280,6 +351,17 @@ cdef class Constituent: # loss = 1 # If we see the start position, set loss to 1 #return loss + @staticmethod + cdef int move_cost(const State* s, const GoldParseC* gold) except -1: + if not Constituent.is_valid(s, -1): + return 9000 + raise Exception("Constituent move should be disabled currently") + + @staticmethod + cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1: + return 0 + + cdef class Adjust: @staticmethod @@ -318,6 +400,16 @@ cdef class Adjust: if not Adjust.is_valid(s, label): return 9000 raise Exception("Adjust move should be disabled currently") + + @staticmethod + cdef int move_cost(const State* s, const GoldParseC* gold) except -1: + if not Adjust.is_valid(s, -1): + return 9000 + raise Exception("Adjust move should be disabled currently") + + @staticmethod + cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1: + return 0 # The gold standard is indexed by end, then by start, then a set of labels #gold_starts = gold.brackets(get_s0(s).r_edge, {}) # Case 1: There are 0 brackets ending at this word. 
@@ -460,32 +552,36 @@ cdef class ArcEager(TransitionSystem): output[i] = is_valid[self.c[i].move] cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1: - cdef Transition move - move.label = -1 + cdef int i, move, label + cdef label_cost_func_t[N_MOVES] label_cost_funcs + cdef move_cost_func_t[N_MOVES] move_cost_funcs cdef int[N_MOVES] move_costs - move_costs[SHIFT] = Shift.cost(s, &gold.c, -1) - move_costs[REDUCE] = Reduce.cost(s, &gold.c, -1) - move_costs[LEFT] = LeftArc.cost(s, &gold.c, -1) - move_costs[RIGHT] = RightArc.cost(s, &gold.c, -1) - move_costs[BREAK] = Break.cost(s, &gold.c, -1) - move_costs[CONSTITUENT] = Constituent.cost(s, &gold.c, -1) - move_costs[ADJUST] = Adjust.cost(s, &gold.c, -1) + for i in range(N_MOVES): + move_costs[i] = -1 + move_cost_funcs[SHIFT] = Shift.move_cost + move_cost_funcs[REDUCE] = Reduce.move_cost + move_cost_funcs[LEFT] = LeftArc.move_cost + move_cost_funcs[RIGHT] = RightArc.move_cost + move_cost_funcs[BREAK] = Break.move_cost + move_cost_funcs[CONSTITUENT] = Constituent.move_cost + move_cost_funcs[ADJUST] = Adjust.move_cost + + label_cost_funcs[SHIFT] = Shift.label_cost + label_cost_funcs[REDUCE] = Reduce.label_cost + label_cost_funcs[LEFT] = LeftArc.label_cost + label_cost_funcs[RIGHT] = RightArc.label_cost + label_cost_funcs[BREAK] = Break.label_cost + label_cost_funcs[CONSTITUENT] = Constituent.label_cost + label_cost_funcs[ADJUST] = Adjust.label_cost - cdef int i, label cdef int* labels = gold.c.labels cdef int* heads = gold.c.heads for i in range(self.n_moves): - move = self.c[i] - output[i] = move_costs[move.move] - if output[i] == 0: - label = -1 - if move.move == RIGHT and heads[s.i] == s.stack[0]: - label = labels[s.i] - if move.move == LEFT and heads[s.stack[0]] == s.i: - label = labels[s.stack[0]] - elif move.move == LEFT and at_eol(s) and (Reduce.is_valid(s, -1) or Break.is_valid(s, 1)): - label = labels[s.stack[0]] - output[i] += move.label != label and label != -1 + move = 
self.c[i].move + label = self.c[i].label + if move_costs[move] == -1: + move_costs[move] = move_cost_funcs[move](s, &gold.c) + output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label) cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: cdef bint[N_MOVES] is_valid diff --git a/spacy/syntax/transition_system.pxd b/spacy/syntax/transition_system.pxd index 99017c306..584e361df 100644 --- a/spacy/syntax/transition_system.pxd +++ b/spacy/syntax/transition_system.pxd @@ -21,6 +21,8 @@ cdef struct Transition: ctypedef int (*get_cost_func_t)(const State* state, const GoldParseC* gold, int label) except -1 +ctypedef int (*move_cost_func_t)(const State* state, const GoldParseC* gold) except -1 +ctypedef int (*label_cost_func_t)(const State* state, const GoldParseC* gold, int label) except -1 ctypedef int (*do_func_t)(State* state, int label) except -1