diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 27fd88915..36730d49d 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -28,10 +28,11 @@ cdef weight_t MIN_SCORE = -90000 # Break transition inspired by this paper: # http://www.aclweb.org/anthology/P13-1074 -# However, there's a significant difference in the constraints. # The most relevant factor is whether we predict Break early, or late: # do we wait until the root is on the stack, or do we predict when the last # word of the previous sentence is on the stack? +# The paper applies Break early. This makes life harder, but we find it's +# worth it to give the model flexibility, and Break when stack may be deep. cdef enum: SHIFT REDUCE @@ -162,7 +163,7 @@ cdef class Split: return 0 elif st.buffer_length == 0: return 0 - elif st.is_split[st.B(0)]: + elif st.was_split[st.B(0)]: return 0 else: return 1 @@ -250,6 +251,7 @@ cdef class LeftArc: @staticmethod cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: + # TODO: Handle oracle for incorrect splits cdef weight_t move_cost = LeftArc.move_cost(s, gold) cdef weight_t label_cost = LeftArc.label_cost(s, gold, label) return move_cost + label_cost @@ -295,6 +297,7 @@ cdef class RightArc: @staticmethod cdef inline weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: + # TODO: Handle oracle for incorrect splits return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label) @staticmethod @@ -431,8 +434,16 @@ cdef class ArcEager(TransitionSystem): # Used for backoff actions[RIGHT].setdefault('dep', 0) actions[LEFT].setdefault('dep', 0) + # TODO: Split? return actions + property max_split: + def __get__(self): + return self.cfg.get('max_split', 0) + + def __set__(self, int value): + self.cfg['max_split'] = value + property action_types: def __get__(self): return (SHIFT, REDUCE, LEFT, RIGHT, BREAK, SPLIT) @@ -474,28 +485,50 @@ cdef class ArcEager(TransitionSystem): def preprocess_gold(self, GoldParse gold): if not self.has_gold(gold): return None - for i, (head_group, dep_group) in enumerate(zip(gold.heads, gold.labels)): - if not USE_SPLIT: - if isinstance(head_group, list): - head_group = [(None, 0)] - dep_group = [None] - # Missing values + subtok_label = self.strings['subtok'] + if USE_SPLIT: + gold.resize_arrays(self.max_split * len(gold)) + # Subtokens are addressed by (subposition, position). + # This way the 'normal' tokens (at subposition 0) occupy positions + # 0...n in the array. + for i in range(1, self.max_split-1): + for j in range(len(gold)): + index = i * len(gold) + j + # If we've incorrectly split, we want to join them back + # up -- so, set the head of each subtoken to the following + # subtoken (until the end), and set the label to 'subtok'. + gold.c.heads[index] = (i+1)*len(gold) + j + gold.c.dep[index] = subtok_label + gold.c.has_dep[index] = True + for j in range(len(gold)): + # For the last subtoken in each position, set head to 'unknown'. + gold.c.heads[index] = index + gold.c.deps[index] = 0 + gold.c.has_dep[index] = False + for child_i, (head_group, dep_group) in enumerate(zip(gold.heads, gold.labels)): + if not USE_SPLIT and isinstance(head_group, list): + # Set as missing values if we don't handle token splitting + head_group = [(None, 0)] + dep_group = [None] if not isinstance(head_group, list): # Map the simple format into the elaborate one we need for # the fused tokens. head_group = [(head_group, 0)] dep_group = [dep_group] - for head_addr, dep in zip(head_group, dep_group): + for child_j, (head_addr, dep) in enumerate(zip(head_group, dep_group)): if not isinstance(head_addr, tuple): head_addr = (head_addr, 0) - head, subtoken = head_addr - if head is None or dep is None: - gold.c.heads[i] = i - gold.c.has_dep[i] = False + head_i, head_j = head_addr + child_index = child_j * len(gold) + child_i + # Missing values + if head_i is None or dep is None: + gold.c.heads[child_index] = child_index + gold.c.has_dep[child_index] = False continue - if head > i: + head_index = head_j * len(gold) + head_i + if (head_i, head_j) > (child_i, child_j): action = LEFT - elif head < i: + elif (head_i, head_j) < (child_i, child_j): action = RIGHT else: action = BREAK @@ -510,11 +543,11 @@ cdef class ArcEager(TransitionSystem): dep = 'dep' else: dep = 'dep' - gold.c.has_dep[i] = True + gold.c.has_dep[child_index] = True if dep.upper() == 'ROOT': dep = 'ROOT' - gold.c.heads[i] = head - gold.c.labels[i] = self.strings.add(dep) + gold.c.heads[child_index] = head_index + gold.c.labels[child_index] = self.strings.add(dep) return gold def get_beam_parses(self, Beam beam):