mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-13 16:44:56 +03:00
Fix feature-flagging of Split action
This commit is contained in:
parent
6cc79fc244
commit
4029dc2cc7
|
@ -19,14 +19,11 @@ from ..structs cimport TokenC
|
||||||
|
|
||||||
# Calculate cost as gold/not gold. We don't use scalar value anyway.
|
# Calculate cost as gold/not gold. We don't use scalar value anyway.
|
||||||
cdef int BINARY_COSTS = 1
|
cdef int BINARY_COSTS = 1
|
||||||
cdef int MAX_SPLIT = 4
|
|
||||||
|
|
||||||
DEF NON_MONOTONIC = True
|
|
||||||
DEF USE_BREAK = True
|
|
||||||
DEF USE_SPLIT = False
|
|
||||||
|
|
||||||
cdef weight_t MIN_SCORE = -90000
|
cdef weight_t MIN_SCORE = -90000
|
||||||
|
|
||||||
|
# Sets NON_MONOTONIC, USE_BREAK, USE_SPLIT, MAX_SPLIT
|
||||||
|
include "compile_time.pxi"
|
||||||
|
|
||||||
# Break transition inspired by this paper:
|
# Break transition inspired by this paper:
|
||||||
# http://www.aclweb.org/anthology/P13-1074
|
# http://www.aclweb.org/anthology/P13-1074
|
||||||
# The most relevant factor is whether we predict Break early, or late:
|
# The most relevant factor is whether we predict Break early, or late:
|
||||||
|
@ -72,19 +69,22 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no
|
||||||
cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0
|
cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0
|
||||||
# If the token wasn't split before, but gold says it *should* be split,
|
# If the token wasn't split before, but gold says it *should* be split,
|
||||||
# don't push (split instead)
|
# don't push (split instead)
|
||||||
if USE_SPLIT and not stcls.c.was_split[stcls.c.B(0)]:
|
if USE_SPLIT:
|
||||||
cost += gold.fused[stcls.c.B(0)]
|
if not stcls.c.was_split[stcls.c.B(0)] and gold.fused[stcls.c.B(0)]:
|
||||||
|
cost += 1
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
|
cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
|
||||||
cdef weight_t cost = 0
|
cdef weight_t cost = 0
|
||||||
cdef int i, B_i
|
cdef int i, B_i
|
||||||
|
# Take into account fused tokens
|
||||||
|
cdef int target_token = target % stcls.c.length
|
||||||
for i in range(stcls.c.segment_length()):
|
for i in range(stcls.c.segment_length()):
|
||||||
B_i = stcls.B(i)
|
B_i = stcls.B(i)
|
||||||
cost += gold.heads[B_i] == target
|
cost += gold.heads[B_i] == target
|
||||||
cost += gold.heads[target] == B_i
|
cost += gold.heads[target] == B_i
|
||||||
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
|
if gold.heads[B_i] == B_i or (gold.heads[B_i]%stcls.c.length) < target:
|
||||||
break
|
break
|
||||||
if BINARY_COSTS and cost >= 1:
|
if BINARY_COSTS and cost >= 1:
|
||||||
return cost
|
return cost
|
||||||
|
@ -97,7 +97,7 @@ cdef weight_t arc_cost(StateClass stcls, const GoldParseC* gold, int head, int c
|
||||||
elif stcls.H(child) == gold.heads[child]:
|
elif stcls.H(child) == gold.heads[child]:
|
||||||
return 1
|
return 1
|
||||||
# Head in buffer
|
# Head in buffer
|
||||||
elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != 0:
|
elif gold.heads[child] >= (stcls.B(0) % stcls.c.length) and stcls.B(1) != 0:
|
||||||
return 1
|
return 1
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
@ -171,7 +171,7 @@ cdef class Split:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(StateC* st, attr_t label) nogil:
|
cdef int transition(StateC* st, attr_t label) nogil:
|
||||||
st.split(0, label)
|
st.split(0, 1)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
|
cdef weight_t cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
|
||||||
|
@ -179,14 +179,18 @@ cdef class Split:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
|
cdef weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
|
||||||
if gold.fused[st.B(0)]:
|
if not USE_SPLIT:
|
||||||
|
return 9000
|
||||||
|
elif gold.fused[st.B(0)]:
|
||||||
return 0
|
return 0
|
||||||
else:
|
else:
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t label_cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
|
cdef weight_t label_cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
|
||||||
if gold.fused[st.B(0)] == label:
|
if not USE_SPLIT:
|
||||||
|
return 9000
|
||||||
|
elif gold.fused[st.B(0)] == 1: #label:
|
||||||
return 0
|
return 0
|
||||||
else:
|
else:
|
||||||
return 1
|
return 1
|
||||||
|
@ -305,7 +309,7 @@ cdef class RightArc:
|
||||||
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
|
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||||
# If the token wasn't split before, but gold says it *should* be split,
|
# If the token wasn't split before, but gold says it *should* be split,
|
||||||
# don't right-arc (split instead)
|
# don't right-arc (split instead)
|
||||||
if not s.c.was_split[s.c.B(0)] and gold.fused[s.c.B(0)]:
|
if USE_SPLIT and not s.c.was_split[s.c.B(0)] and gold.fused[s.c.B(0)]:
|
||||||
return gold.fused[s.c.B(0)]
|
return gold.fused[s.c.B(0)]
|
||||||
elif arc_is_gold(gold, s.S(0), s.B(0)):
|
elif arc_is_gold(gold, s.S(0), s.B(0)):
|
||||||
return 0
|
return 0
|
||||||
|
@ -438,16 +442,16 @@ cdef class ArcEager(TransitionSystem):
|
||||||
# TODO: Split?
|
# TODO: Split?
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
#property max_split:
|
property max_split:
|
||||||
# def __get__(self):
|
def __get__(self):
|
||||||
# return self.cfg.get('max_split', 0)
|
return MAX_SPLIT
|
||||||
|
|
||||||
# def __set__(self, int value):
|
|
||||||
# self.cfg['max_split'] = value
|
|
||||||
|
|
||||||
property action_types:
|
property action_types:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
|
if USE_SPLIT:
|
||||||
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK, SPLIT)
|
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK, SPLIT)
|
||||||
|
else:
|
||||||
|
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
|
||||||
|
|
||||||
def get_cost(self, StateClass state, GoldParse gold, action):
|
def get_cost(self, StateClass state, GoldParse gold, action):
|
||||||
cdef Transition t = self.lookup_transition(action)
|
cdef Transition t = self.lookup_transition(action)
|
||||||
|
@ -493,7 +497,7 @@ cdef class ArcEager(TransitionSystem):
|
||||||
# Subtokens are addressed by (subposition, position).
|
# Subtokens are addressed by (subposition, position).
|
||||||
# This way the 'normal' tokens (at subposition 0) occupy positions
|
# This way the 'normal' tokens (at subposition 0) occupy positions
|
||||||
# 0...n in the array.
|
# 0...n in the array.
|
||||||
for i in range(1, MAX_SPLIT-1):
|
for i in range(1, MAX_SPLIT):
|
||||||
for j in range(len(gold)):
|
for j in range(len(gold)):
|
||||||
index = i * len(gold) + j
|
index = i * len(gold) + j
|
||||||
# If we've incorrectly split, we want to join them back
|
# If we've incorrectly split, we want to join them back
|
||||||
|
@ -507,6 +511,11 @@ cdef class ArcEager(TransitionSystem):
|
||||||
gold.c.heads[index] = index
|
gold.c.heads[index] = index
|
||||||
gold.c.labels[index] = 0
|
gold.c.labels[index] = 0
|
||||||
gold.c.has_dep[index] = False
|
gold.c.has_dep[index] = False
|
||||||
|
for i in range(len(gold)):
|
||||||
|
if isinstance(gold.heads[i], list):
|
||||||
|
gold.c.fused[i] = len(gold.heads)-1
|
||||||
|
else:
|
||||||
|
gold.c.fused[i] = 0
|
||||||
for child_i, (head_group, dep_group) in enumerate(zip(gold.heads, gold.labels)):
|
for child_i, (head_group, dep_group) in enumerate(zip(gold.heads, gold.labels)):
|
||||||
if not USE_SPLIT and (isinstance(head_group, list) or isinstance(head_group, tuple)):
|
if not USE_SPLIT and (isinstance(head_group, list) or isinstance(head_group, tuple)):
|
||||||
# Set as missing values if we don't handle token splitting
|
# Set as missing values if we don't handle token splitting
|
||||||
|
@ -521,6 +530,9 @@ cdef class ArcEager(TransitionSystem):
|
||||||
if not isinstance(head_addr, tuple):
|
if not isinstance(head_addr, tuple):
|
||||||
head_addr = (head_addr, 0)
|
head_addr = (head_addr, 0)
|
||||||
head_i, head_j = head_addr
|
head_i, head_j = head_addr
|
||||||
|
if not USE_SPLIT:
|
||||||
|
head_j = 0
|
||||||
|
child_j = 0
|
||||||
child_index = child_j * len(gold) + child_i
|
child_index = child_j * len(gold) + child_i
|
||||||
# Missing values
|
# Missing values
|
||||||
if head_i is None or dep is None:
|
if head_i is None or dep is None:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user