Update ArcEager for changes to GoldParse class

This commit is contained in:
Matthew Honnibal 2018-04-02 23:53:13 +02:00
parent c9d314b7ba
commit aa5ecf7fd2

View File

@ -19,6 +19,7 @@ from ..structs cimport TokenC
# Calculate cost as gold/not gold. We don't use scalar value anyway. # Calculate cost as gold/not gold. We don't use scalar value anyway.
cdef int BINARY_COSTS = 1 cdef int BINARY_COSTS = 1
cdef int MAX_SPLIT = 4
DEF NON_MONOTONIC = True DEF NON_MONOTONIC = True
DEF USE_BREAK = True DEF USE_BREAK = True
@ -437,12 +438,12 @@ cdef class ArcEager(TransitionSystem):
# TODO: Split? # TODO: Split?
return actions return actions
property max_split: #property max_split:
def __get__(self): # def __get__(self):
return self.cfg.get('max_split', 0) # return self.cfg.get('max_split', 0)
def __set__(self, int value): # def __set__(self, int value):
self.cfg['max_split'] = value # self.cfg['max_split'] = value
property action_types: property action_types:
def __get__(self): def __get__(self):
@ -464,14 +465,15 @@ cdef class ArcEager(TransitionSystem):
predicted = set() predicted = set()
truth = set() truth = set()
for i in range(gold.length): for i in range(gold.length):
if gold.cand_to_gold[i] is None: gold_i = gold._alignment.index_to_yours(i)
if gold_i is None:
continue continue
if state.safe_get(i).dep: if state.safe_get(i).dep:
predicted.add((i, state.H(i), predicted.add((i, state.H(i),
self.strings[state.safe_get(i).dep])) self.strings[state.safe_get(i).dep]))
else: else:
predicted.add((i, state.H(i), 'ROOT')) predicted.add((i, state.H(i), 'ROOT'))
id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]] id_, word, tag, head, dep, ner = gold.orig_annot[gold_i]
truth.add((id_, head, dep)) truth.add((id_, head, dep))
return truth == predicted return truth == predicted
@ -487,23 +489,23 @@ cdef class ArcEager(TransitionSystem):
return None return None
subtok_label = self.strings['subtok'] subtok_label = self.strings['subtok']
if USE_SPLIT: if USE_SPLIT:
gold.resize_arrays(self.max_split * len(gold)) gold.resize_arrays(MAX_SPLIT * len(gold))
# Subtokens are addressed by (subposition, position). # Subtokens are addressed by (subposition, position).
# This way the 'normal' tokens (at subposition 0) occupy positions # This way the 'normal' tokens (at subposition 0) occupy positions
# 0...n in the array. # 0...n in the array.
for i in range(1, self.max_split-1): for i in range(1, MAX_SPLIT-1):
for j in range(len(gold)): for j in range(len(gold)):
index = i * len(gold) + j index = i * len(gold) + j
# If we've incorrectly split, we want to join them back # If we've incorrectly split, we want to join them back
# up -- so, set the head of each subtoken to the following # up -- so, set the head of each subtoken to the following
# subtoken (until the end), and set the label to 'subtok'. # subtoken (until the end), and set the label to 'subtok'.
gold.c.heads[index] = (i+1)*len(gold) + j gold.c.heads[index] = (i+1)*len(gold) + j
gold.c.dep[index] = subtok_label gold.c.labels[index] = subtok_label
gold.c.has_dep[index] = True gold.c.has_dep[index] = True
for j in range(len(gold)): for j in range(len(gold)):
# For the last subtoken in each position, set head to 'unknown'. # For the last subtoken in each position, set head to 'unknown'.
gold.c.heads[index] = index gold.c.heads[index] = index
gold.c.deps[index] = 0 gold.c.labels[index] = 0
gold.c.has_dep[index] = False gold.c.has_dep[index] = False
for child_i, (head_group, dep_group) in enumerate(zip(gold.heads, gold.labels)): for child_i, (head_group, dep_group) in enumerate(zip(gold.heads, gold.labels)):
if not USE_SPLIT and isinstance(head_group, list): if not USE_SPLIT and isinstance(head_group, list):