From e2136232f9f46eb8b297ffd7441f68c8f5e7ebda Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 22 May 2017 10:30:12 -0500 Subject: [PATCH] Exclude states with no matching gold annotations from parsing --- spacy/syntax/arc_eager.pyx | 5 ++++- spacy/syntax/ner.pyx | 5 ++++- spacy/syntax/nn_parser.pyx | 5 ++--- spacy/syntax/transition_system.pxd | 2 -- spacy/syntax/transition_system.pyx | 2 +- 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 2030a01ca..0a1422088 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -350,7 +350,9 @@ cdef class ArcEager(TransitionSystem): def __get__(self): return (SHIFT, REDUCE, LEFT, RIGHT, BREAK) - cdef int preprocess_gold(self, GoldParse gold) except -1: + def preprocess_gold(self, GoldParse gold): + if all([h is None for h in gold.heads]): + return None for i in range(gold.length): if gold.heads[i] is None: # Missing values gold.c.heads[i] = i @@ -361,6 +363,7 @@ cdef class ArcEager(TransitionSystem): label = 'ROOT' gold.c.heads[i] = gold.heads[i] gold.c.labels[i] = self.strings[label] + return gold cdef Transition lookup_transition(self, object name) except *: if '-' in name: diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index c2712c231..74ab9c26c 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -95,9 +95,12 @@ cdef class BiluoPushDown(TransitionSystem): else: return MOVE_NAMES[move] + '-' + self.strings[label] - cdef int preprocess_gold(self, GoldParse gold) except -1: + def preprocess_gold(self, GoldParse gold): + if all([tag == '-' for tag in gold.ner]): + return None for i in range(gold.length): gold.c.ner[i] = self.lookup_transition(gold.ner[i]) + return gold cdef Transition lookup_transition(self, object name) except *: if name == '-' or name == None: diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 81e44e84b..338ad2005 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -318,15 +318,14 @@ cdef class Parser: golds = [golds] cuda_stream = get_cuda_stream() - for gold in golds: - self.moves.preprocess_gold(gold) + golds = [self.moves.preprocess_gold(g) for g in golds] states = self.moves.init_batch(docs) state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream, drop) todo = [(s, g) for (s, g) in zip(states, golds) - if not s.is_final()] + if not s.is_final() and g is not None] backprops = [] cdef float loss = 0. diff --git a/spacy/syntax/transition_system.pxd b/spacy/syntax/transition_system.pxd index 5169ff7ca..e61cf154c 100644 --- a/spacy/syntax/transition_system.pxd +++ b/spacy/syntax/transition_system.pxd @@ -43,8 +43,6 @@ cdef class TransitionSystem: cdef int initialize_state(self, StateC* state) nogil cdef int finalize_state(self, StateC* state) nogil - cdef int preprocess_gold(self, GoldParse gold) except -1 - cdef Transition lookup_transition(self, object name) except * cdef Transition init_transition(self, int clas, int move, int label) except * diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index 74b768dfb..d6750d09c 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -70,7 +70,7 @@ cdef class TransitionSystem: def finalize_doc(self, doc): pass - cdef int preprocess_gold(self, GoldParse gold) except -1: + def preprocess_gold(self, GoldParse gold): raise NotImplementedError cdef Transition lookup_transition(self, object name) except *: