mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 08:12:24 +03:00
Add is_gold_parse method to transition system
This commit is contained in:
parent
3533bb61cb
commit
a6d8d7c82e
|
@ -351,6 +351,20 @@ cdef class ArcEager(TransitionSystem):
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
|
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
|
||||||
|
|
||||||
|
def is_gold_parse(self, StateClass state, GoldParse gold):
|
||||||
|
predicted = set()
|
||||||
|
truth = set()
|
||||||
|
for i in range(gold.length):
|
||||||
|
if gold.cand_to_gold[i] is None:
|
||||||
|
continue
|
||||||
|
if state.safe_get(i).dep:
|
||||||
|
predicted.add((i, state.H(i), self.strings[state.safe_get(i).dep]))
|
||||||
|
else:
|
||||||
|
predicted.add((i, state.H(i), 'ROOT'))
|
||||||
|
id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]]
|
||||||
|
truth.add((id_, head, dep))
|
||||||
|
return truth == predicted
|
||||||
|
|
||||||
def has_gold(self, GoldParse gold, start=0, end=None):
|
def has_gold(self, GoldParse gold, start=0, end=None):
|
||||||
end = end or len(gold.heads)
|
end = end or len(gold.heads)
|
||||||
if all([tag is None for tag in gold.heads[start:end]]):
|
if all([tag is None for tag in gold.heads[start:end]]):
|
||||||
|
|
|
@ -99,6 +99,9 @@ cdef class TransitionSystem:
|
||||||
def preprocess_gold(self, GoldParse gold):
|
def preprocess_gold(self, GoldParse gold):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def is_gold_parse(self, StateClass state, GoldParse gold):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
cdef Transition lookup_transition(self, object name) except *:
|
cdef Transition lookup_transition(self, object name) except *:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user