Record negative costs in non-monotonic arc eager oracle

This commit is contained in:
Matthew Honnibal 2017-03-10 11:22:04 -06:00
parent ecf91a2dbb
commit d11f1a4ddf

View File

@ -1,3 +1,6 @@
# cython: profile=True
# cython: cdivision=True
# cython: infer_types=True
from __future__ import unicode_literals
import ctypes
@ -155,7 +158,18 @@ cdef class Reduce:
# Cost of the Reduce move. Starts from pop_cost() evaluated for st.S(0) and
# is then decremented for gold arcs that link st.S(0) with deeper stack
# items, so the returned value can fall below the pop_cost() baseline.
# NOTE(review): this listing is a diff hunk whose +/- markers and
# indentation were stripped; the nesting of the statements below cannot be
# read off from here -- confirm against the real file.
@staticmethod
cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
# NOTE(review): presumably the pre-change body line that this commit
# replaces with the `cost = ...` assignment just below -- confirm.
return pop_cost(st, gold, st.S(0))
# Start from pop_cost() for st.S(0).
cost = pop_cost(st, gold, st.S(0))
# Guard: taken only when st.has_head(st.S(0)) is false.
if not st.has_head(st.S(0)):
# Decrement cost for the arcs we save
# Visits stack positions 1 .. stack_depth() - 1 (position 0 is skipped).
for i in range(1, st.stack_depth()):
S_i = st.S(i)
# gold.heads lists S_i as the head of st.S(0).
if gold.heads[st.S(0)] == S_i:
cost -= 1
# gold.heads lists st.S(0) as the head of S_i.
if gold.heads[S_i] == st.S(0):
cost -= 1
# One further decrement when Break is valid and its move cost is 0.
# NOTE(review): whether this check sits inside the has_head guard is not
# visible in this flattened view -- confirm.
if Break.is_valid(st.c, -1) and Break.move_cost(st, gold) == 0:
cost -= 1
return cost
@staticmethod
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
@ -181,7 +195,8 @@ cdef class LeftArc:
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
cdef weight_t cost = 0
if arc_is_gold(gold, s.B(0), s.S(0)):
return 0
# Have a negative cost if we 'recover' from the wrong dependency
return 0 if not s.has_head(s.S(0)) else -1
else:
# Account for deps we might lose between S0 and stack
if not s.has_head(s.S(0)):
@ -414,7 +429,7 @@ cdef class ArcEager(TransitionSystem):
cdef move_cost_func_t[N_MOVES] move_cost_funcs
cdef weight_t[N_MOVES] move_costs
for i in range(N_MOVES):
move_costs[i] = -1
move_costs[i] = 9000
move_cost_funcs[SHIFT] = Shift.move_cost
move_cost_funcs[REDUCE] = Reduce.move_cost
move_cost_funcs[LEFT] = LeftArc.move_cost
@ -436,14 +451,14 @@ cdef class ArcEager(TransitionSystem):
is_valid[i] = True
move = self.c[i].move
label = self.c[i].label
if move_costs[move] == -1:
if move_costs[move] == 9000:
move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
n_gold += costs[i] <= 0
else:
is_valid[i] = False
costs[i] = 9000
if n_gold == 0:
if n_gold < 1:
# Check projectivity --- leading cause
if is_nonproj_tree(gold.heads):
raise ValueError(