Update ArcEager for changes to GoldParse class

This commit is contained in:
Matthew Honnibal 2018-04-02 23:53:13 +02:00
parent c9d314b7ba
commit aa5ecf7fd2

View File

@ -19,6 +19,7 @@ from ..structs cimport TokenC
# Calculate cost as gold/not gold. We don't use scalar value anyway.
cdef int BINARY_COSTS = 1
cdef int MAX_SPLIT = 4
DEF NON_MONOTONIC = True
DEF USE_BREAK = True
@ -437,12 +438,12 @@ cdef class ArcEager(TransitionSystem):
# TODO: Split?
return actions
property max_split:
def __get__(self):
return self.cfg.get('max_split', 0)
#property max_split:
# def __get__(self):
# return self.cfg.get('max_split', 0)
def __set__(self, int value):
self.cfg['max_split'] = value
# def __set__(self, int value):
# self.cfg['max_split'] = value
property action_types:
def __get__(self):
@ -464,14 +465,15 @@ cdef class ArcEager(TransitionSystem):
predicted = set()
truth = set()
for i in range(gold.length):
if gold.cand_to_gold[i] is None:
gold_i = gold._alignment.index_to_yours(i)
if gold_i is None:
continue
if state.safe_get(i).dep:
predicted.add((i, state.H(i),
self.strings[state.safe_get(i).dep]))
else:
predicted.add((i, state.H(i), 'ROOT'))
id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]]
id_, word, tag, head, dep, ner = gold.orig_annot[gold_i]
truth.add((id_, head, dep))
return truth == predicted
@ -487,23 +489,23 @@ cdef class ArcEager(TransitionSystem):
return None
subtok_label = self.strings['subtok']
if USE_SPLIT:
gold.resize_arrays(self.max_split * len(gold))
gold.resize_arrays(MAX_SPLIT * len(gold))
# Subtokens are addressed by (subposition, position).
# This way the 'normal' tokens (at subposition 0) occupy positions
# 0...n in the array.
for i in range(1, self.max_split-1):
for i in range(1, MAX_SPLIT-1):
for j in range(len(gold)):
index = i * len(gold) + j
# If we've incorrectly split, we want to join them back
# up -- so, set the head of each subtoken to the following
# subtoken (until the end), and set the label to 'subtok'.
gold.c.heads[index] = (i+1)*len(gold) + j
gold.c.dep[index] = subtok_label
gold.c.labels[index] = subtok_label
gold.c.has_dep[index] = True
for j in range(len(gold)):
# For the last subtoken in each position, set head to 'unknown'.
gold.c.heads[index] = index
gold.c.deps[index] = 0
gold.c.labels[index] = 0
gold.c.has_dep[index] = False
for child_i, (head_group, dep_group) in enumerate(zip(gold.heads, gold.labels)):
if not USE_SPLIT and isinstance(head_group, list):