From a6d8d7c82e269c2f91ca25ac73cc70dcc9cf890d Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 16 Aug 2017 18:24:09 -0500 Subject: [PATCH] Add is_gold_parse method to transition system --- spacy/syntax/arc_eager.pyx | 14 ++++++++++++++ spacy/syntax/transition_system.pyx | 3 +++ 2 files changed, 17 insertions(+) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 9477449a5..aab350d76 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -351,6 +351,20 @@ cdef class ArcEager(TransitionSystem): def __get__(self): return (SHIFT, REDUCE, LEFT, RIGHT, BREAK) + def is_gold_parse(self, StateClass state, GoldParse gold): + predicted = set() + truth = set() + for i in range(gold.length): + if gold.cand_to_gold[i] is None: + continue + if state.safe_get(i).dep: + predicted.add((i, state.H(i), self.strings[state.safe_get(i).dep])) + else: + predicted.add((i, state.H(i), 'ROOT')) + id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]] + truth.add((id_, head, dep)) + return truth == predicted + def has_gold(self, GoldParse gold, start=0, end=None): end = end or len(gold.heads) if all([tag is None for tag in gold.heads[start:end]]): diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index d3f64f827..9cf82e0c7 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -99,6 +99,9 @@ cdef class TransitionSystem: def preprocess_gold(self, GoldParse gold): raise NotImplementedError + def is_gold_parse(self, StateClass state, GoldParse gold): + raise NotImplementedError + cdef Transition lookup_transition(self, object name) except *: raise NotImplementedError