* Refactor transition system oracles to split out move cost and label cost, in preparation for adding an Unshift move. Will exclude non-monotonic.

This commit is contained in:
Matthew Honnibal 2015-06-07 03:21:29 +02:00
parent e2578fbb90
commit 8f142c1838
2 changed files with 141 additions and 43 deletions

View File

@ -11,6 +11,7 @@ from ._state cimport count_left_kids
from ..structs cimport TokenC
from .transition_system cimport do_func_t, get_cost_func_t
from .transition_system cimport move_cost_func_t, label_cost_func_t
from ..gold cimport GoldParse
from ..gold cimport GoldParseC
@ -46,6 +47,35 @@ MOVE_NAMES[CONSTITUENT] = 'C'
MOVE_NAMES[ADJUST] = 'A'
# Helper functions for the arc-eager oracle
cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -1:
    # Cost of pushing `target`. Once a word is pushed we can no longer make
    # arcs between it and the stack, so any such arcs are lost. The cost is
    # the head_in_stack count plus the children_in_stack count for `target`,
    # both evaluated against gold.heads.
    return (head_in_stack(st, target, gold.heads)
            + children_in_stack(st, target, gold.heads))
cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1:
    # Cost of popping `target`: the children_in_buffer count plus the
    # head_in_buffer count for `target`, both evaluated against gold.heads.
    # (Presumably these are the gold arcs between `target` and the buffer
    # that can no longer be built once it is popped -- confirm against the
    # helper definitions.)
    return (children_in_buffer(st, target, gold.heads)
            + head_in_buffer(st, target, gold.heads))
cdef int arc_cost(const GoldParseC* gold, int head, int child, int label) except -1:
    # Label cost of attaching `child` to `head` with `label`.
    #
    # Returns 1 only when the arc head -> child is in the gold parse, the gold
    # label for `child` is known (not -1), a concrete label was requested
    # (not -1), and the two labels differ. Returns 0 in every other case.
    #
    # label == -1 means "no specific label". It is treated as free, matching
    # LeftArc.label_cost and the pre-refactor check
    # `label != -1 and label != gold.labels[...]`.
    if label == -1:
        return 0
    if gold.heads[child] != head:
        # Wrong arc: that is charged by the move cost, not the label cost.
        return 0
    if gold.labels[child] == -1:
        # Gold label unknown: any label is acceptable.
        return 0
    return 1 if gold.labels[child] != label else 0
cdef class Shift:
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
@ -62,14 +92,20 @@ cdef class Shift:
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
    # Total oracle cost of Shift with `label`.
    # Returns 9000 when the move is not valid in state `s`; otherwise the sum
    # of Shift.move_cost and Shift.label_cost.
    # The old inline head_in_stack/children_in_stack accumulation was dropped:
    # its result was never used once the cost was split into move + label.
    if not Shift.is_valid(s, label):
        return 9000
    return Shift.move_cost(s, gold) + Shift.label_cost(s, gold, label)
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
    # Label-independent cost of Shift: the push cost of the buffer-front word
    # s.i, plus 1 when a Break is both valid and free.
    cdef int cost = push_cost(s, gold, s.i)
    # If we can break, and there's no cost to doing so, we should.
    # Break validity is queried with label -1: this function takes no label.
    if Break.is_valid(s, -1) and Break.cost(s, gold, -1) == 0:
        cost += 1
    return cost
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
# Label cost of Shift: always 0; none of the arguments are consulted.
return 0
cdef class Reduce:
@staticmethod
@ -89,11 +125,19 @@ cdef class Reduce:
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
    # Total oracle cost of Reduce with `label`.
    # Returns 9000 when the move is not valid in state `s`; otherwise the sum
    # of Reduce.move_cost and Reduce.label_cost.
    # The old inline children_in_buffer accumulation was dropped: its result
    # was never used once the cost was split into move + label.
    if not Reduce.is_valid(s, label):
        return 9000
    return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
    # Label-independent cost of Reduce, i.e. of popping the stack top
    # s.stack[0].
    if NON_MONOTONIC:
        # pop_cost counts both children_in_buffer and head_in_buffer.
        return pop_cost(s, gold, s.stack[0])
    else:
        # Otherwise only the children_in_buffer count is charged.
        return children_in_buffer(s, s.stack[0], gold.heads)
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
# Label cost of Reduce: always 0; none of the arguments are consulted.
return 0
cdef class LeftArc:
@ -117,19 +161,21 @@ cdef class LeftArc:
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
    # Total oracle cost of LeftArc with `label`: move cost plus label cost
    # when the move is valid in state `s`, otherwise the sentinel 9000.
    if LeftArc.is_valid(s, label):
        return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
    return 9000
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
if not LeftArc.is_valid(s, -1):
return 9000
cdef int cost = 0
if gold.heads[s.stack[0]] == s.i:
cost += label != -1 and label != gold.labels[s.stack[0]]
return cost
# If we're at EOL, then the left arc will add an arc to ROOT.
elif at_eol(s):
# Are we root?
if gold.labels[s.stack[0]] != -1:
# If we're at EOL, prefer to reduce or break over left-arc
if Reduce.is_valid(s, -1) or Break.is_valid(s, -1):
cost += gold.heads[s.stack[0]] != s.stack[0]
# Are we labelling correctly?
cost += label != -1 and label != gold.labels[s.stack[0]]
return cost
cost += head_in_buffer(s, s.stack[0], gold.heads)
cost += children_in_buffer(s, s.stack[0], gold.heads)
@ -139,6 +185,14 @@ cdef class LeftArc:
cost += gold.heads[s.stack[0]] == s.stack[0]
return cost
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
    # Label cost of LeftArc (stack top s.stack[0] attached to s.i).
    # Returns 1 only when a concrete label is requested, the gold label of
    # the stack top is known, the gold head of the stack top is s.i, and the
    # requested label differs from the gold one. Returns 0 otherwise.
    if label == -1:
        return 0
    if gold.labels[s.stack[0]] == -1:
        return 0
    if gold.heads[s.stack[0]] != s.i:
        return 0
    return 1 if label != gold.labels[s.stack[0]] else 0
cdef class RightArc:
@staticmethod
@ -154,17 +208,25 @@ cdef class RightArc:
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
    # Total oracle cost of RightArc with `label`.
    # Returns 9000 when the move is not valid in state `s`; otherwise the sum
    # of RightArc.move_cost and RightArc.label_cost.
    # The old inline label check was dropped: that logic now lives in
    # RightArc.label_cost.
    if not RightArc.is_valid(s, label):
        return 9000
    return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
    # Label-independent cost of RightArc: the push cost of s.i, reduced by 1
    # when the gold head of s.i is the stack top s.stack[0] -- the arc this
    # move creates.
    cdef int cost = push_cost(s, gold, s.i)
    if gold.heads[s.i] == s.stack[0]:
        cost -= 1
    return cost
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
    # Label cost of RightArc: the arc_cost of attaching s.i to the stack top
    # s.stack[0] with `label`.
    # The statements that used to follow the return were unreachable and
    # referenced an undeclared `cost`; they and the commented-out
    # pre-refactor code have been removed.
    return arc_cost(gold, s.stack[0], s.i, label)
cdef class Break:
@ -207,6 +269,11 @@ cdef class Break:
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
    # Total oracle cost of Break with `label`: move cost plus label cost when
    # the move is valid in state `s`, otherwise the sentinel 9000.
    if Break.is_valid(s, label):
        return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
    return 9000
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
# When we break, we Reduce all of the words on the stack.
cdef int cost = 0
# Number of deps between S0...Sn and N0...Nn
@ -214,6 +281,10 @@ cdef class Break:
cost += children_in_stack(s, i, gold.heads)
cost += head_in_stack(s, i, gold.heads)
return cost
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
# Label cost of Break: always 0; none of the arguments are consulted.
return 0
cdef class Constituent:
@ -280,6 +351,17 @@ cdef class Constituent:
# loss = 1 # If we see the start position, set loss to 1
#return loss
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
    # The Constituent move is not supported by the oracle at the moment.
    # A valid Constituent move raises; an invalid one returns the sentinel
    # 9000. Validity is queried with label -1.
    if Constituent.is_valid(s, -1):
        raise Exception("Constituent move should be disabled currently")
    return 9000
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
# Label cost of Constituent: always 0; none of the arguments are consulted.
return 0
cdef class Adjust:
@staticmethod
@ -318,6 +400,16 @@ cdef class Adjust:
if not Adjust.is_valid(s, label):
return 9000
raise Exception("Adjust move should be disabled currently")
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
    # The Adjust move is not supported by the oracle at the moment.
    # A valid Adjust move raises; an invalid one returns the sentinel 9000.
    # Validity is queried with label -1.
    if Adjust.is_valid(s, -1):
        raise Exception("Adjust move should be disabled currently")
    return 9000
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
# Label cost of Adjust: always 0; none of the arguments are consulted.
return 0
# The gold standard is indexed by end, then by start, then a set of labels
#gold_starts = gold.brackets(get_s0(s).r_edge, {})
# Case 1: There are 0 brackets ending at this word.
@ -460,32 +552,36 @@ cdef class ArcEager(TransitionSystem):
output[i] = is_valid[self.c[i].move]
cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1:
    # Fill output[i] with the oracle cost of transition self.c[i] in state
    # `s`, for i in range(self.n_moves).
    #
    # Each cost is move_cost + label_cost. The move cost depends only on the
    # move type, so it is computed at most once per type and cached in
    # move_costs; the label cost is evaluated for every transition.
    cdef int i, move, label
    cdef label_cost_func_t[N_MOVES] label_cost_funcs
    cdef move_cost_func_t[N_MOVES] move_cost_funcs
    cdef int[N_MOVES] move_costs
    # -1 marks "move cost not computed yet".
    for i in range(N_MOVES):
        move_costs[i] = -1

    # Dispatch tables indexed by move type.
    move_cost_funcs[SHIFT] = Shift.move_cost
    move_cost_funcs[REDUCE] = Reduce.move_cost
    move_cost_funcs[LEFT] = LeftArc.move_cost
    move_cost_funcs[RIGHT] = RightArc.move_cost
    move_cost_funcs[BREAK] = Break.move_cost
    move_cost_funcs[CONSTITUENT] = Constituent.move_cost
    move_cost_funcs[ADJUST] = Adjust.move_cost

    label_cost_funcs[SHIFT] = Shift.label_cost
    label_cost_funcs[REDUCE] = Reduce.label_cost
    label_cost_funcs[LEFT] = LeftArc.label_cost
    label_cost_funcs[RIGHT] = RightArc.label_cost
    label_cost_funcs[BREAK] = Break.label_cost
    label_cost_funcs[CONSTITUENT] = Constituent.label_cost
    label_cost_funcs[ADJUST] = Adjust.label_cost

    for i in range(self.n_moves):
        move = self.c[i].move
        label = self.c[i].label
        # Computed lazily, so a move type that no transition uses is never
        # evaluated.
        if move_costs[move] == -1:
            move_costs[move] = move_cost_funcs[move](s, &gold.c)
        output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label)
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
cdef bint[N_MOVES] is_valid

View File

@ -21,6 +21,8 @@ cdef struct Transition:
# Function-pointer types used by the transition system. All of them return an
# int and are declared `except -1`.

# Cost function taking the state, the gold parse and a label.
ctypedef int (*get_cost_func_t)(const State* state, const GoldParseC* gold, int label) except -1
# Cost function taking only the state and the gold parse (no label).
ctypedef int (*move_cost_func_t)(const State* state, const GoldParseC* gold) except -1
# Cost function taking the state, the gold parse and a label; same signature
# as get_cost_func_t.
ctypedef int (*label_cost_func_t)(const State* state, const GoldParseC* gold, int label) except -1
# Function taking a mutable state and a label (presumably applies the
# transition -- confirm against the implementations).
ctypedef int (*do_func_t)(State* state, int label) except -1