mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 02:36:32 +03:00
Add is_gold_parse method to transition system
This commit is contained in:
parent
3533bb61cb
commit
a6d8d7c82e
|
@ -351,6 +351,20 @@ cdef class ArcEager(TransitionSystem):
|
|||
def __get__(self):
|
||||
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
|
||||
|
||||
def is_gold_parse(self, StateClass state, GoldParse gold):
|
||||
predicted = set()
|
||||
truth = set()
|
||||
for i in range(gold.length):
|
||||
if gold.cand_to_gold[i] is None:
|
||||
continue
|
||||
if state.safe_get(i).dep:
|
||||
predicted.add((i, state.H(i), self.strings[state.safe_get(i).dep]))
|
||||
else:
|
||||
predicted.add((i, state.H(i), 'ROOT'))
|
||||
id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]]
|
||||
truth.add((id_, head, dep))
|
||||
return truth == predicted
|
||||
|
||||
def has_gold(self, GoldParse gold, start=0, end=None):
|
||||
end = end or len(gold.heads)
|
||||
if all([tag is None for tag in gold.heads[start:end]]):
|
||||
|
|
|
@ -99,6 +99,9 @@ cdef class TransitionSystem:
|
|||
def preprocess_gold(self, GoldParse gold):
|
||||
raise NotImplementedError
|
||||
|
||||
def is_gold_parse(self, StateClass state, GoldParse gold):
|
||||
raise NotImplementedError
|
||||
|
||||
cdef Transition lookup_transition(self, object name) except *:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user