Fix feature-flagging of Split action

This commit is contained in:
Matthew Honnibal 2018-04-03 15:43:50 +02:00
parent 6cc79fc244
commit 4029dc2cc7

View File

@@ -19,14 +19,11 @@ from ..structs cimport TokenC
# Calculate cost as gold/not gold. We don't use scalar value anyway.
cdef int BINARY_COSTS = 1
cdef int MAX_SPLIT = 4
DEF NON_MONOTONIC = True
DEF USE_BREAK = True
DEF USE_SPLIT = False
cdef weight_t MIN_SCORE = -90000
# Sets NON_MONOTONIC, USE_BREAK, USE_SPLIT, MAX_SPLIT
include "compile_time.pxi"
# Break transition inspired by this paper:
# http://www.aclweb.org/anthology/P13-1074
# The most relevant factor is whether we predict Break early, or late:
@@ -72,19 +69,22 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no
cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0
# If the token wasn't split before, but gold says it *should* be split,
# don't push (split instead)
if USE_SPLIT and not stcls.c.was_split[stcls.c.B(0)]:
cost += gold.fused[stcls.c.B(0)]
if USE_SPLIT:
if not stcls.c.was_split[stcls.c.B(0)] and gold.fused[stcls.c.B(0)]:
cost += 1
return cost
cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
cdef weight_t cost = 0
cdef int i, B_i
# Take into account fused tokens
cdef int target_token = target % stcls.c.length
for i in range(stcls.c.segment_length()):
B_i = stcls.B(i)
cost += gold.heads[B_i] == target
cost += gold.heads[target] == B_i
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
if gold.heads[B_i] == B_i or (gold.heads[B_i]%stcls.c.length) < target:
break
if BINARY_COSTS and cost >= 1:
return cost
@@ -97,7 +97,7 @@ cdef weight_t arc_cost(StateClass stcls, const GoldParseC* gold, int head, int c
elif stcls.H(child) == gold.heads[child]:
return 1
# Head in buffer
elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != 0:
elif gold.heads[child] >= (stcls.B(0) % stcls.c.length) and stcls.B(1) != 0:
return 1
else:
return 0
@@ -171,7 +171,7 @@ cdef class Split:
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
st.split(0, label)
st.split(0, 1)
@staticmethod
cdef weight_t cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
@@ -179,14 +179,18 @@ cdef class Split:
@staticmethod
cdef weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
if gold.fused[st.B(0)]:
if not USE_SPLIT:
return 9000
elif gold.fused[st.B(0)]:
return 0
else:
return 1
@staticmethod
cdef weight_t label_cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
if gold.fused[st.B(0)] == label:
if not USE_SPLIT:
return 9000
elif gold.fused[st.B(0)] == 1: #label:
return 0
else:
return 1
@@ -305,7 +309,7 @@ cdef class RightArc:
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
# If the token wasn't split before, but gold says it *should* be split,
# don't right-arc (split instead)
if not s.c.was_split[s.c.B(0)] and gold.fused[s.c.B(0)]:
if USE_SPLIT and not s.c.was_split[s.c.B(0)] and gold.fused[s.c.B(0)]:
return gold.fused[s.c.B(0)]
elif arc_is_gold(gold, s.S(0), s.B(0)):
return 0
@@ -438,16 +442,16 @@ cdef class ArcEager(TransitionSystem):
# TODO: Split?
return actions
#property max_split:
# def __get__(self):
# return self.cfg.get('max_split', 0)
# def __set__(self, int value):
# self.cfg['max_split'] = value
property max_split:
def __get__(self):
return MAX_SPLIT
property action_types:
def __get__(self):
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK, SPLIT)
if USE_SPLIT:
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK, SPLIT)
else:
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
def get_cost(self, StateClass state, GoldParse gold, action):
cdef Transition t = self.lookup_transition(action)
@@ -493,7 +497,7 @@ cdef class ArcEager(TransitionSystem):
# Subtokens are addressed by (subposition, position).
# This way the 'normal' tokens (at subposition 0) occupy positions
# 0...n in the array.
for i in range(1, MAX_SPLIT-1):
for i in range(1, MAX_SPLIT):
for j in range(len(gold)):
index = i * len(gold) + j
# If we've incorrectly split, we want to join them back
@@ -507,6 +511,11 @@ cdef class ArcEager(TransitionSystem):
gold.c.heads[index] = index
gold.c.labels[index] = 0
gold.c.has_dep[index] = False
for i in range(len(gold)):
if isinstance(gold.heads[i], list):
gold.c.fused[i] = len(gold.heads)-1
else:
gold.c.fused[i] = 0
for child_i, (head_group, dep_group) in enumerate(zip(gold.heads, gold.labels)):
if not USE_SPLIT and (isinstance(head_group, list) or isinstance(head_group, tuple)):
# Set as missing values if we don't handle token splitting
@@ -521,6 +530,9 @@ cdef class ArcEager(TransitionSystem):
if not isinstance(head_addr, tuple):
head_addr = (head_addr, 0)
head_i, head_j = head_addr
if not USE_SPLIT:
head_j = 0
child_j = 0
child_index = child_j * len(gold) + child_i
# Missing values
if head_i is None or dep is None: