mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-10 15:14:56 +03:00
Draft ArcEager.preprocess_gold for fused tokens
This commit is contained in:
parent
fb9c3984b5
commit
9c3612d40b
|
@ -28,10 +28,11 @@ cdef weight_t MIN_SCORE = -90000
|
|||
|
||||
# Break transition inspired by this paper:
|
||||
# http://www.aclweb.org/anthology/P13-1074
|
||||
# However, there's a significant difference in the constraints.
|
||||
# The most relevant factor is whether we predict Break early, or late:
|
||||
# do we wait until the root is on the stack, or do we predict when the last
|
||||
# word of the previous sentence is on the stack?
|
||||
# The paper applies Break early. This makes life harder, but we find it's
|
||||
# worth it to give the model flexibility, and Break when stack may be deep.
|
||||
cdef enum:
|
||||
SHIFT
|
||||
REDUCE
|
||||
|
@ -162,7 +163,7 @@ cdef class Split:
|
|||
return 0
|
||||
elif st.buffer_length == 0:
|
||||
return 0
|
||||
elif st.is_split[st.B(0)]:
|
||||
elif st.was_split[st.B(0)]:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
@ -250,6 +251,7 @@ cdef class LeftArc:
|
|||
|
||||
@staticmethod
|
||||
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
||||
# TODO: Handle oracle for incorrect splits
|
||||
cdef weight_t move_cost = LeftArc.move_cost(s, gold)
|
||||
cdef weight_t label_cost = LeftArc.label_cost(s, gold, label)
|
||||
return move_cost + label_cost
|
||||
|
@ -295,6 +297,7 @@ cdef class RightArc:
|
|||
|
||||
@staticmethod
|
||||
cdef inline weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
||||
# TODO: Handle oracle for incorrect splits
|
||||
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
|
@ -431,8 +434,16 @@ cdef class ArcEager(TransitionSystem):
|
|||
# Used for backoff
|
||||
actions[RIGHT].setdefault('dep', 0)
|
||||
actions[LEFT].setdefault('dep', 0)
|
||||
# TODO: Split?
|
||||
return actions
|
||||
|
||||
property max_split:
|
||||
def __get__(self):
|
||||
return self.cfg.get('max_split', 0)
|
||||
|
||||
def __set__(self, int value):
|
||||
self.cfg['max_split'] = value
|
||||
|
||||
property action_types:
|
||||
def __get__(self):
|
||||
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK, SPLIT)
|
||||
|
@ -474,28 +485,50 @@ cdef class ArcEager(TransitionSystem):
|
|||
def preprocess_gold(self, GoldParse gold):
|
||||
if not self.has_gold(gold):
|
||||
return None
|
||||
for i, (head_group, dep_group) in enumerate(zip(gold.heads, gold.labels)):
|
||||
if not USE_SPLIT:
|
||||
if isinstance(head_group, list):
|
||||
head_group = [(None, 0)]
|
||||
dep_group = [None]
|
||||
# Missing values
|
||||
subtok_label = self.strings['subtok']
|
||||
if USE_SPLIT:
|
||||
gold.resize_arrays(self.max_split * len(gold))
|
||||
# Subtokens are addressed by (subposition, position).
|
||||
# This way the 'normal' tokens (at subposition 0) occupy positions
|
||||
# 0...n in the array.
|
||||
for i in range(1, self.max_split-1):
|
||||
for j in range(len(gold)):
|
||||
index = i * len(gold) + j
|
||||
# If we've incorrectly split, we want to join them back
|
||||
# up -- so, set the head of each subtoken to the following
|
||||
# subtoken (until the end), and set the label to 'subtok'.
|
||||
gold.c.heads[index] = (i+1)*len(gold) + j
|
||||
gold.c.dep[index] = subtok_label
|
||||
gold.c.has_dep[index] = True
|
||||
for j in range(len(gold)):
|
||||
# For the last subtoken in each position, set head to 'unknown'.
|
||||
gold.c.heads[index] = index
|
||||
gold.c.deps[index] = 0
|
||||
gold.c.has_dep[index] = False
|
||||
for child_i, (head_group, dep_group) in enumerate(zip(gold.heads, gold.labels)):
|
||||
if not USE_SPLIT and isinstance(head_group, list):
|
||||
# Set as missing values if we don't handle token splitting
|
||||
head_group = [(None, 0)]
|
||||
dep_group = [None]
|
||||
if not isinstance(head_group, list):
|
||||
# Map the simple format into the elaborate one we need for
|
||||
# the fused tokens.
|
||||
head_group = [(head_group, 0)]
|
||||
dep_group = [dep_group]
|
||||
for head_addr, dep in zip(head_group, dep_group):
|
||||
for child_j, (head_addr, dep) in enumerate(zip(head_group, dep_group)):
|
||||
if not isinstance(head_addr, tuple):
|
||||
head_addr = (head_addr, 0)
|
||||
head, subtoken = head_addr
|
||||
if head is None or dep is None:
|
||||
gold.c.heads[i] = i
|
||||
gold.c.has_dep[i] = False
|
||||
head_i, head_j = head_addr
|
||||
child_index = child_j * len(gold) + child_i
|
||||
# Missing values
|
||||
if head_i is None or dep is None:
|
||||
gold.c.heads[child_index] = child_index
|
||||
gold.c.has_dep[child_index] = False
|
||||
continue
|
||||
if head > i:
|
||||
head_index = head_j * len(gold) + head_i
|
||||
if (head_i, head_j) > (child_i, child_j):
|
||||
action = LEFT
|
||||
elif head < i:
|
||||
elif (head_i, head_j) < (child_i, child_j):
|
||||
action = RIGHT
|
||||
else:
|
||||
action = BREAK
|
||||
|
@ -510,11 +543,11 @@ cdef class ArcEager(TransitionSystem):
|
|||
dep = 'dep'
|
||||
else:
|
||||
dep = 'dep'
|
||||
gold.c.has_dep[i] = True
|
||||
gold.c.has_dep[child_index] = True
|
||||
if dep.upper() == 'ROOT':
|
||||
dep = 'ROOT'
|
||||
gold.c.heads[i] = head
|
||||
gold.c.labels[i] = self.strings.add(dep)
|
||||
gold.c.heads[child_index] = head_index
|
||||
gold.c.labels[child_index] = self.strings.add(dep)
|
||||
return gold
|
||||
|
||||
def get_beam_parses(self, Beam beam):
|
||||
|
|
Loading…
Reference in New Issue
Block a user