mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-13 16:44:56 +03:00
Draft ArcEager.preprocess_gold for fused tokens
This commit is contained in:
parent
fb9c3984b5
commit
9c3612d40b
|
@ -28,10 +28,11 @@ cdef weight_t MIN_SCORE = -90000
|
||||||
|
|
||||||
# Break transition inspired by this paper:
|
# Break transition inspired by this paper:
|
||||||
# http://www.aclweb.org/anthology/P13-1074
|
# http://www.aclweb.org/anthology/P13-1074
|
||||||
# However, there's a significant difference in the constraints.
|
|
||||||
# The most relevant factor is whether we predict Break early, or late:
|
# The most relevant factor is whether we predict Break early, or late:
|
||||||
# do we wait until the root is on the stack, or do we predict when the last
|
# do we wait until the root is on the stack, or do we predict when the last
|
||||||
# word of the previous sentence is on the stack?
|
# word of the previous sentence is on the stack?
|
||||||
|
# The paper applies Break early. This makes life harder, but we find it's
|
||||||
|
# worth it to give the model flexibility, and Break when stack may be deep.
|
||||||
cdef enum:
|
cdef enum:
|
||||||
SHIFT
|
SHIFT
|
||||||
REDUCE
|
REDUCE
|
||||||
|
@ -162,7 +163,7 @@ cdef class Split:
|
||||||
return 0
|
return 0
|
||||||
elif st.buffer_length == 0:
|
elif st.buffer_length == 0:
|
||||||
return 0
|
return 0
|
||||||
elif st.is_split[st.B(0)]:
|
elif st.was_split[st.B(0)]:
|
||||||
return 0
|
return 0
|
||||||
else:
|
else:
|
||||||
return 1
|
return 1
|
||||||
|
@ -250,6 +251,7 @@ cdef class LeftArc:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
||||||
|
# TODO: Handle oracle for incorrect splits
|
||||||
cdef weight_t move_cost = LeftArc.move_cost(s, gold)
|
cdef weight_t move_cost = LeftArc.move_cost(s, gold)
|
||||||
cdef weight_t label_cost = LeftArc.label_cost(s, gold, label)
|
cdef weight_t label_cost = LeftArc.label_cost(s, gold, label)
|
||||||
return move_cost + label_cost
|
return move_cost + label_cost
|
||||||
|
@ -295,6 +297,7 @@ cdef class RightArc:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef inline weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
||||||
|
# TODO: Handle oracle for incorrect splits
|
||||||
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
|
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -431,8 +434,16 @@ cdef class ArcEager(TransitionSystem):
|
||||||
# Used for backoff
|
# Used for backoff
|
||||||
actions[RIGHT].setdefault('dep', 0)
|
actions[RIGHT].setdefault('dep', 0)
|
||||||
actions[LEFT].setdefault('dep', 0)
|
actions[LEFT].setdefault('dep', 0)
|
||||||
|
# TODO: Split?
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
|
property max_split:
|
||||||
|
def __get__(self):
|
||||||
|
return self.cfg.get('max_split', 0)
|
||||||
|
|
||||||
|
def __set__(self, int value):
|
||||||
|
self.cfg['max_split'] = value
|
||||||
|
|
||||||
property action_types:
|
property action_types:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK, SPLIT)
|
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK, SPLIT)
|
||||||
|
@ -474,28 +485,50 @@ cdef class ArcEager(TransitionSystem):
|
||||||
def preprocess_gold(self, GoldParse gold):
|
def preprocess_gold(self, GoldParse gold):
|
||||||
if not self.has_gold(gold):
|
if not self.has_gold(gold):
|
||||||
return None
|
return None
|
||||||
for i, (head_group, dep_group) in enumerate(zip(gold.heads, gold.labels)):
|
subtok_label = self.strings['subtok']
|
||||||
if not USE_SPLIT:
|
if USE_SPLIT:
|
||||||
if isinstance(head_group, list):
|
gold.resize_arrays(self.max_split * len(gold))
|
||||||
head_group = [(None, 0)]
|
# Subtokens are addressed by (subposition, position).
|
||||||
dep_group = [None]
|
# This way the 'normal' tokens (at subposition 0) occupy positions
|
||||||
# Missing values
|
# 0...n in the array.
|
||||||
|
for i in range(1, self.max_split-1):
|
||||||
|
for j in range(len(gold)):
|
||||||
|
index = i * len(gold) + j
|
||||||
|
# If we've incorrectly split, we want to join them back
|
||||||
|
# up -- so, set the head of each subtoken to the following
|
||||||
|
# subtoken (until the end), and set the label to 'subtok'.
|
||||||
|
gold.c.heads[index] = (i+1)*len(gold) + j
|
||||||
|
gold.c.dep[index] = subtok_label
|
||||||
|
gold.c.has_dep[index] = True
|
||||||
|
for j in range(len(gold)):
|
||||||
|
# For the last subtoken in each position, set head to 'unknown'.
|
||||||
|
gold.c.heads[index] = index
|
||||||
|
gold.c.deps[index] = 0
|
||||||
|
gold.c.has_dep[index] = False
|
||||||
|
for child_i, (head_group, dep_group) in enumerate(zip(gold.heads, gold.labels)):
|
||||||
|
if not USE_SPLIT and isinstance(head_group, list):
|
||||||
|
# Set as missing values if we don't handle token splitting
|
||||||
|
head_group = [(None, 0)]
|
||||||
|
dep_group = [None]
|
||||||
if not isinstance(head_group, list):
|
if not isinstance(head_group, list):
|
||||||
# Map the simple format into the elaborate one we need for
|
# Map the simple format into the elaborate one we need for
|
||||||
# the fused tokens.
|
# the fused tokens.
|
||||||
head_group = [(head_group, 0)]
|
head_group = [(head_group, 0)]
|
||||||
dep_group = [dep_group]
|
dep_group = [dep_group]
|
||||||
for head_addr, dep in zip(head_group, dep_group):
|
for child_j, (head_addr, dep) in enumerate(zip(head_group, dep_group)):
|
||||||
if not isinstance(head_addr, tuple):
|
if not isinstance(head_addr, tuple):
|
||||||
head_addr = (head_addr, 0)
|
head_addr = (head_addr, 0)
|
||||||
head, subtoken = head_addr
|
head_i, head_j = head_addr
|
||||||
if head is None or dep is None:
|
child_index = child_j * len(gold) + child_i
|
||||||
gold.c.heads[i] = i
|
# Missing values
|
||||||
gold.c.has_dep[i] = False
|
if head_i is None or dep is None:
|
||||||
|
gold.c.heads[child_index] = child_index
|
||||||
|
gold.c.has_dep[child_index] = False
|
||||||
continue
|
continue
|
||||||
if head > i:
|
head_index = head_j * len(gold) + head_i
|
||||||
|
if (head_i, head_j) > (child_i, child_j):
|
||||||
action = LEFT
|
action = LEFT
|
||||||
elif head < i:
|
elif (head_i, head_j) < (child_i, child_j):
|
||||||
action = RIGHT
|
action = RIGHT
|
||||||
else:
|
else:
|
||||||
action = BREAK
|
action = BREAK
|
||||||
|
@ -510,11 +543,11 @@ cdef class ArcEager(TransitionSystem):
|
||||||
dep = 'dep'
|
dep = 'dep'
|
||||||
else:
|
else:
|
||||||
dep = 'dep'
|
dep = 'dep'
|
||||||
gold.c.has_dep[i] = True
|
gold.c.has_dep[child_index] = True
|
||||||
if dep.upper() == 'ROOT':
|
if dep.upper() == 'ROOT':
|
||||||
dep = 'ROOT'
|
dep = 'ROOT'
|
||||||
gold.c.heads[i] = head
|
gold.c.heads[child_index] = head_index
|
||||||
gold.c.labels[i] = self.strings.add(dep)
|
gold.c.labels[child_index] = self.strings.add(dep)
|
||||||
return gold
|
return gold
|
||||||
|
|
||||||
def get_beam_parses(self, Beam beam):
|
def get_beam_parses(self, Beam beam):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user