From 4029dc2cc7e79a968a1cd79bece3013760846c3b Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 3 Apr 2018 15:43:50 +0200 Subject: [PATCH] Fix feature-flagging of Split action --- spacy/syntax/arc_eager.pyx | 56 +++++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 37278a6f3..1b03208ba 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -19,14 +19,11 @@ from ..structs cimport TokenC # Calculate cost as gold/not gold. We don't use scalar value anyway. cdef int BINARY_COSTS = 1 -cdef int MAX_SPLIT = 4 - -DEF NON_MONOTONIC = True -DEF USE_BREAK = True -DEF USE_SPLIT = False - cdef weight_t MIN_SCORE = -90000 +# Sets NON_MONOTONIC, USE_BREAK, USE_SPLIT, MAX_SPLIT +include "compile_time.pxi" + # Break transition inspired by this paper: # http://www.aclweb.org/anthology/P13-1074 # The most relevant factor is whether we predict Break early, or late: @@ -72,19 +69,22 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0 # If the token wasn't split before, but gold says it *should* be split, # don't push (split instead) - if USE_SPLIT and not stcls.c.was_split[stcls.c.B(0)]: - cost += gold.fused[stcls.c.B(0)] + if USE_SPLIT: + if not stcls.c.was_split[stcls.c.B(0)] and gold.fused[stcls.c.B(0)]: + cost += 1 return cost cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil: cdef weight_t cost = 0 cdef int i, B_i + # Take into account fused tokens + cdef int target_token = target % stcls.c.length for i in range(stcls.c.segment_length()): B_i = stcls.B(i) cost += gold.heads[B_i] == target cost += gold.heads[target] == B_i - if gold.heads[B_i] == B_i or gold.heads[B_i] < target: + if gold.heads[B_i] == B_i or (gold.heads[B_i]%stcls.c.length) < target: break if BINARY_COSTS and cost >= 1: return cost @@ -97,7 +97,7 @@ cdef weight_t arc_cost(StateClass stcls, const GoldParseC* gold, int head, int c elif stcls.H(child) == gold.heads[child]: return 1 # Head in buffer - elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != 0: + elif gold.heads[child] >= (stcls.B(0) % stcls.c.length) and stcls.B(1) != 0: return 1 else: return 0 @@ -171,7 +171,7 @@ cdef class Split: @staticmethod cdef int transition(StateC* st, attr_t label) nogil: - st.split(0, label) + st.split(0, 1) @staticmethod cdef weight_t cost(StateClass st, const GoldParseC* gold, attr_t label) nogil: @@ -179,14 +179,18 @@ cdef class Split: @staticmethod cdef weight_t move_cost(StateClass st, const GoldParseC* gold) nogil: - if gold.fused[st.B(0)]: + if not USE_SPLIT: + return 9000 + elif gold.fused[st.B(0)]: return 0 else: return 1 @staticmethod cdef weight_t label_cost(StateClass st, const GoldParseC* gold, attr_t label) nogil: - if gold.fused[st.B(0)] == label: + if not USE_SPLIT: + return 9000 + elif gold.fused[st.B(0)] == 1: #label: return 0 else: return 1 @@ -305,7 +309,7 @@ cdef class RightArc: cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil: # If the token wasn't split before, but gold says it *should* be split, # don't right-arc (split instead) - if not s.c.was_split[s.c.B(0)] and gold.fused[s.c.B(0)]: + if USE_SPLIT and not s.c.was_split[s.c.B(0)] and gold.fused[s.c.B(0)]: return gold.fused[s.c.B(0)] elif arc_is_gold(gold, s.S(0), s.B(0)): return 0 @@ -438,16 +442,16 @@ cdef class ArcEager(TransitionSystem): # TODO: Split? return actions - #property max_split: - # def __get__(self): - # return self.cfg.get('max_split', 0) - - # def __set__(self, int value): - # self.cfg['max_split'] = value + property max_split: + def __get__(self): + return MAX_SPLIT property action_types: def __get__(self): - return (SHIFT, REDUCE, LEFT, RIGHT, BREAK, SPLIT) + if USE_SPLIT: + return (SHIFT, REDUCE, LEFT, RIGHT, BREAK, SPLIT) + else: + return (SHIFT, REDUCE, LEFT, RIGHT, BREAK) def get_cost(self, StateClass state, GoldParse gold, action): cdef Transition t = self.lookup_transition(action) @@ -493,7 +497,7 @@ cdef class ArcEager(TransitionSystem): # Subtokens are addressed by (subposition, position). # This way the 'normal' tokens (at subposition 0) occupy positions # 0...n in the array. - for i in range(1, MAX_SPLIT-1): + for i in range(1, MAX_SPLIT): for j in range(len(gold)): index = i * len(gold) + j # If we've incorrectly split, we want to join them back @@ -507,6 +511,11 @@ cdef class ArcEager(TransitionSystem): gold.c.heads[index] = index gold.c.labels[index] = 0 gold.c.has_dep[index] = False + for i in range(len(gold)): + if isinstance(gold.heads[i], list): + gold.c.fused[i] = len(gold.heads)-1 + else: + gold.c.fused[i] = 0 for child_i, (head_group, dep_group) in enumerate(zip(gold.heads, gold.labels)): if not USE_SPLIT and (isinstance(head_group, list) or isinstance(head_group, tuple)): # Set as missing values if we don't handle token splitting @@ -521,6 +530,9 @@ cdef class ArcEager(TransitionSystem): if not isinstance(head_addr, tuple): head_addr = (head_addr, 0) head_i, head_j = head_addr + if not USE_SPLIT: + head_j = 0 + child_j = 0 child_index = child_j * len(gold) + child_i # Missing values if head_i is None or dep is None: