Draft ArcEager.preprocess_gold for fused tokens

This commit is contained in:
Matthew Honnibal 2018-04-01 22:11:35 +02:00
parent fb9c3984b5
commit 9c3612d40b

View File

@ -28,10 +28,11 @@ cdef weight_t MIN_SCORE = -90000
# Break transition inspired by this paper:
# http://www.aclweb.org/anthology/P13-1074
# However, there's a significant difference in the constraints.
# The most relevant factor is whether we predict Break early, or late:
# do we wait until the root is on the stack, or do we predict when the last
# word of the previous sentence is on the stack?
# The paper applies Break early. This makes life harder, but we find it's
# worth it to give the model flexibility, and Break when stack may be deep.
cdef enum:
SHIFT
REDUCE
@ -162,7 +163,7 @@ cdef class Split:
return 0
elif st.buffer_length == 0:
return 0
elif st.is_split[st.B(0)]:
elif st.was_split[st.B(0)]:
return 0
else:
return 1
@ -250,6 +251,7 @@ cdef class LeftArc:
@staticmethod
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
# TODO: Handle oracle for incorrect splits
cdef weight_t move_cost = LeftArc.move_cost(s, gold)
cdef weight_t label_cost = LeftArc.label_cost(s, gold, label)
return move_cost + label_cost
@ -295,6 +297,7 @@ cdef class RightArc:
@staticmethod
cdef inline weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
# TODO: Handle oracle for incorrect splits
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
@staticmethod
@ -431,8 +434,16 @@ cdef class ArcEager(TransitionSystem):
# Used for backoff
actions[RIGHT].setdefault('dep', 0)
actions[LEFT].setdefault('dep', 0)
# TODO: Split?
return actions
property max_split:
def __get__(self):
return self.cfg.get('max_split', 0)
def __set__(self, int value):
self.cfg['max_split'] = value
property action_types:
def __get__(self):
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK, SPLIT)
@ -474,28 +485,50 @@ cdef class ArcEager(TransitionSystem):
def preprocess_gold(self, GoldParse gold):
if not self.has_gold(gold):
return None
for i, (head_group, dep_group) in enumerate(zip(gold.heads, gold.labels)):
if not USE_SPLIT:
if isinstance(head_group, list):
head_group = [(None, 0)]
dep_group = [None]
# Missing values
subtok_label = self.strings['subtok']
if USE_SPLIT:
gold.resize_arrays(self.max_split * len(gold))
# Subtokens are addressed by (subposition, position).
# This way the 'normal' tokens (at subposition 0) occupy positions
# 0...n in the array.
for i in range(1, self.max_split-1):
for j in range(len(gold)):
index = i * len(gold) + j
# If we've incorrectly split, we want to join them back
# up -- so, set the head of each subtoken to the following
# subtoken (until the end), and set the label to 'subtok'.
gold.c.heads[index] = (i+1)*len(gold) + j
gold.c.dep[index] = subtok_label
gold.c.has_dep[index] = True
for j in range(len(gold)):
# For the last subtoken in each position, set head to 'unknown'.
gold.c.heads[index] = index
gold.c.deps[index] = 0
gold.c.has_dep[index] = False
for child_i, (head_group, dep_group) in enumerate(zip(gold.heads, gold.labels)):
if not USE_SPLIT and isinstance(head_group, list):
# Set as missing values if we don't handle token splitting
head_group = [(None, 0)]
dep_group = [None]
if not isinstance(head_group, list):
# Map the simple format into the elaborate one we need for
# the fused tokens.
head_group = [(head_group, 0)]
dep_group = [dep_group]
for head_addr, dep in zip(head_group, dep_group):
for child_j, (head_addr, dep) in enumerate(zip(head_group, dep_group)):
if not isinstance(head_addr, tuple):
head_addr = (head_addr, 0)
head, subtoken = head_addr
if head is None or dep is None:
gold.c.heads[i] = i
gold.c.has_dep[i] = False
head_i, head_j = head_addr
child_index = child_j * len(gold) + child_i
# Missing values
if head_i is None or dep is None:
gold.c.heads[child_index] = child_index
gold.c.has_dep[child_index] = False
continue
if head > i:
head_index = head_j * len(gold) + head_i
if (head_i, head_j) > (child_i, child_j):
action = LEFT
elif head < i:
elif (head_i, head_j) < (child_i, child_j):
action = RIGHT
else:
action = BREAK
@ -510,11 +543,11 @@ cdef class ArcEager(TransitionSystem):
dep = 'dep'
else:
dep = 'dep'
gold.c.has_dep[i] = True
gold.c.has_dep[child_index] = True
if dep.upper() == 'ROOT':
dep = 'ROOT'
gold.c.heads[i] = head
gold.c.labels[i] = self.strings.add(dep)
gold.c.heads[child_index] = head_index
gold.c.labels[child_index] = self.strings.add(dep)
return gold
def get_beam_parses(self, Beam beam):