diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 36730d49d..f61b15e7a 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -19,6 +19,7 @@ from ..structs cimport TokenC # Calculate cost as gold/not gold. We don't use scalar value anyway. cdef int BINARY_COSTS = 1 +cdef int MAX_SPLIT = 4 DEF NON_MONOTONIC = True DEF USE_BREAK = True @@ -437,12 +438,12 @@ cdef class ArcEager(TransitionSystem): # TODO: Split? return actions - property max_split: - def __get__(self): - return self.cfg.get('max_split', 0) + #property max_split: + # def __get__(self): + # return self.cfg.get('max_split', 0) - def __set__(self, int value): - self.cfg['max_split'] = value + # def __set__(self, int value): + # self.cfg['max_split'] = value property action_types: def __get__(self): @@ -464,14 +465,15 @@ cdef class ArcEager(TransitionSystem): predicted = set() truth = set() for i in range(gold.length): - if gold.cand_to_gold[i] is None: + gold_i = gold._alignment.index_to_yours(i) + if gold_i is None: continue if state.safe_get(i).dep: predicted.add((i, state.H(i), self.strings[state.safe_get(i).dep])) else: predicted.add((i, state.H(i), 'ROOT')) - id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]] + id_, word, tag, head, dep, ner = gold.orig_annot[gold_i] truth.add((id_, head, dep)) return truth == predicted @@ -487,23 +489,23 @@ cdef class ArcEager(TransitionSystem): return None subtok_label = self.strings['subtok'] if USE_SPLIT: - gold.resize_arrays(self.max_split * len(gold)) + gold.resize_arrays(MAX_SPLIT * len(gold)) # Subtokens are addressed by (subposition, position). # This way the 'normal' tokens (at subposition 0) occupy positions # 0...n in the array. - for i in range(1, self.max_split-1): + for i in range(1, MAX_SPLIT-1): for j in range(len(gold)): index = i * len(gold) + j # If we've incorrectly split, we want to join them back # up -- so, set the head of each subtoken to the following # subtoken (until the end), and set the label to 'subtok'. gold.c.heads[index] = (i+1)*len(gold) + j - gold.c.dep[index] = subtok_label + gold.c.labels[index] = subtok_label gold.c.has_dep[index] = True for j in range(len(gold)): # For the last subtoken in each position, set head to 'unknown'. gold.c.heads[index] = index - gold.c.deps[index] = 0 + gold.c.labels[index] = 0 gold.c.has_dep[index] = False for child_i, (head_group, dep_group) in enumerate(zip(gold.heads, gold.labels)): if not USE_SPLIT and isinstance(head_group, list):