mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-10 15:14:56 +03:00
Fix feature-flagging of Split action
This commit is contained in:
parent
6cc79fc244
commit
4029dc2cc7
|
@ -19,14 +19,11 @@ from ..structs cimport TokenC
|
|||
|
||||
# Calculate cost as gold/not gold. We don't use scalar value anyway.
|
||||
cdef int BINARY_COSTS = 1
|
||||
cdef int MAX_SPLIT = 4
|
||||
|
||||
DEF NON_MONOTONIC = True
|
||||
DEF USE_BREAK = True
|
||||
DEF USE_SPLIT = False
|
||||
|
||||
cdef weight_t MIN_SCORE = -90000
|
||||
|
||||
# Sets NON_MONOTONIC, USE_BREAK, USE_SPLIT, MAX_SPLIT
|
||||
include "compile_time.pxi"
|
||||
|
||||
# Break transition inspired by this paper:
|
||||
# http://www.aclweb.org/anthology/P13-1074
|
||||
# The most relevant factor is whether we predict Break early, or late:
|
||||
|
@ -72,19 +69,22 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no
|
|||
cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0
|
||||
# If the token wasn't split before, but gold says it *should* be split,
|
||||
# don't push (split instead)
|
||||
if USE_SPLIT and not stcls.c.was_split[stcls.c.B(0)]:
|
||||
cost += gold.fused[stcls.c.B(0)]
|
||||
if USE_SPLIT:
|
||||
if not stcls.c.was_split[stcls.c.B(0)] and gold.fused[stcls.c.B(0)]:
|
||||
cost += 1
|
||||
return cost
|
||||
|
||||
|
||||
cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
|
||||
cdef weight_t cost = 0
|
||||
cdef int i, B_i
|
||||
# Take into account fused tokens
|
||||
cdef int target_token = target % stcls.c.length
|
||||
for i in range(stcls.c.segment_length()):
|
||||
B_i = stcls.B(i)
|
||||
cost += gold.heads[B_i] == target
|
||||
cost += gold.heads[target] == B_i
|
||||
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
|
||||
if gold.heads[B_i] == B_i or (gold.heads[B_i]%stcls.c.length) < target:
|
||||
break
|
||||
if BINARY_COSTS and cost >= 1:
|
||||
return cost
|
||||
|
@ -97,7 +97,7 @@ cdef weight_t arc_cost(StateClass stcls, const GoldParseC* gold, int head, int c
|
|||
elif stcls.H(child) == gold.heads[child]:
|
||||
return 1
|
||||
# Head in buffer
|
||||
elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != 0:
|
||||
elif gold.heads[child] >= (stcls.B(0) % stcls.c.length) and stcls.B(1) != 0:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
@ -171,7 +171,7 @@ cdef class Split:
|
|||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
st.split(0, label)
|
||||
st.split(0, 1)
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
|
||||
|
@ -179,14 +179,18 @@ cdef class Split:
|
|||
|
||||
@staticmethod
|
||||
cdef weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
|
||||
if gold.fused[st.B(0)]:
|
||||
if not USE_SPLIT:
|
||||
return 9000
|
||||
elif gold.fused[st.B(0)]:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t label_cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
|
||||
if gold.fused[st.B(0)] == label:
|
||||
if not USE_SPLIT:
|
||||
return 9000
|
||||
elif gold.fused[st.B(0)] == 1: #label:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
@ -305,7 +309,7 @@ cdef class RightArc:
|
|||
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||
# If the token wasn't split before, but gold says it *should* be split,
|
||||
# don't right-arc (split instead)
|
||||
if not s.c.was_split[s.c.B(0)] and gold.fused[s.c.B(0)]:
|
||||
if USE_SPLIT and not s.c.was_split[s.c.B(0)] and gold.fused[s.c.B(0)]:
|
||||
return gold.fused[s.c.B(0)]
|
||||
elif arc_is_gold(gold, s.S(0), s.B(0)):
|
||||
return 0
|
||||
|
@ -438,16 +442,16 @@ cdef class ArcEager(TransitionSystem):
|
|||
# TODO: Split?
|
||||
return actions
|
||||
|
||||
#property max_split:
|
||||
# def __get__(self):
|
||||
# return self.cfg.get('max_split', 0)
|
||||
|
||||
# def __set__(self, int value):
|
||||
# self.cfg['max_split'] = value
|
||||
property max_split:
|
||||
def __get__(self):
|
||||
return MAX_SPLIT
|
||||
|
||||
property action_types:
|
||||
def __get__(self):
|
||||
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK, SPLIT)
|
||||
if USE_SPLIT:
|
||||
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK, SPLIT)
|
||||
else:
|
||||
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
|
||||
|
||||
def get_cost(self, StateClass state, GoldParse gold, action):
|
||||
cdef Transition t = self.lookup_transition(action)
|
||||
|
@ -493,7 +497,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
# Subtokens are addressed by (subposition, position).
|
||||
# This way the 'normal' tokens (at subposition 0) occupy positions
|
||||
# 0...n in the array.
|
||||
for i in range(1, MAX_SPLIT-1):
|
||||
for i in range(1, MAX_SPLIT):
|
||||
for j in range(len(gold)):
|
||||
index = i * len(gold) + j
|
||||
# If we've incorrectly split, we want to join them back
|
||||
|
@ -507,6 +511,11 @@ cdef class ArcEager(TransitionSystem):
|
|||
gold.c.heads[index] = index
|
||||
gold.c.labels[index] = 0
|
||||
gold.c.has_dep[index] = False
|
||||
for i in range(len(gold)):
|
||||
if isinstance(gold.heads[i], list):
|
||||
gold.c.fused[i] = len(gold.heads)-1
|
||||
else:
|
||||
gold.c.fused[i] = 0
|
||||
for child_i, (head_group, dep_group) in enumerate(zip(gold.heads, gold.labels)):
|
||||
if not USE_SPLIT and (isinstance(head_group, list) or isinstance(head_group, tuple)):
|
||||
# Set as missing values if we don't handle token splitting
|
||||
|
@ -521,6 +530,9 @@ cdef class ArcEager(TransitionSystem):
|
|||
if not isinstance(head_addr, tuple):
|
||||
head_addr = (head_addr, 0)
|
||||
head_i, head_j = head_addr
|
||||
if not USE_SPLIT:
|
||||
head_j = 0
|
||||
child_j = 0
|
||||
child_index = child_j * len(gold) + child_i
|
||||
# Missing values
|
||||
if head_i is None or dep is None:
|
||||
|
|
Loading…
Reference in New Issue
Block a user