Add is_gold_parse method to transition system

This commit is contained in:
Matthew Honnibal 2017-08-16 18:24:09 -05:00
parent 3533bb61cb
commit a6d8d7c82e
2 changed files with 17 additions and 0 deletions

View File

@ -351,6 +351,20 @@ cdef class ArcEager(TransitionSystem):
def __get__(self):
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
def is_gold_parse(self, StateClass state, GoldParse gold):
predicted = set()
truth = set()
for i in range(gold.length):
if gold.cand_to_gold[i] is None:
continue
if state.safe_get(i).dep:
predicted.add((i, state.H(i), self.strings[state.safe_get(i).dep]))
else:
predicted.add((i, state.H(i), 'ROOT'))
id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]]
truth.add((id_, head, dep))
return truth == predicted
def has_gold(self, GoldParse gold, start=0, end=None):
end = end or len(gold.heads)
if all([tag is None for tag in gold.heads[start:end]]):

View File

@ -99,6 +99,9 @@ cdef class TransitionSystem:
def preprocess_gold(self, GoldParse gold):
raise NotImplementedError
def is_gold_parse(self, StateClass state, GoldParse gold):
raise NotImplementedError
cdef Transition lookup_transition(self, object name) except *:
raise NotImplementedError