Exclude states with no matching gold annotations from parsing

This commit is contained in:
Matthew Honnibal 2017-05-22 10:30:12 -05:00
parent 83ffd16474
commit e2136232f9
5 changed files with 11 additions and 8 deletions

View File

@@ -350,7 +350,9 @@ cdef class ArcEager(TransitionSystem):
def __get__(self):
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
cdef int preprocess_gold(self, GoldParse gold) except -1:
def preprocess_gold(self, GoldParse gold):
if all([h is None for h in gold.heads]):
return None
for i in range(gold.length):
if gold.heads[i] is None: # Missing values
gold.c.heads[i] = i
@@ -361,6 +363,7 @@ cdef class ArcEager(TransitionSystem):
label = 'ROOT'
gold.c.heads[i] = gold.heads[i]
gold.c.labels[i] = self.strings[label]
return gold
cdef Transition lookup_transition(self, object name) except *:
if '-' in name:

View File

@@ -95,9 +95,12 @@ cdef class BiluoPushDown(TransitionSystem):
else:
return MOVE_NAMES[move] + '-' + self.strings[label]
cdef int preprocess_gold(self, GoldParse gold) except -1:
def preprocess_gold(self, GoldParse gold):
if all([tag == '-' for tag in gold.ner]):
return None
for i in range(gold.length):
gold.c.ner[i] = self.lookup_transition(gold.ner[i])
return gold
cdef Transition lookup_transition(self, object name) except *:
if name == '-' or name == None:

View File

@@ -318,15 +318,14 @@ cdef class Parser:
golds = [golds]
cuda_stream = get_cuda_stream()
for gold in golds:
self.moves.preprocess_gold(gold)
golds = [self.moves.preprocess_gold(g) for g in golds]
states = self.moves.init_batch(docs)
state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream,
drop)
todo = [(s, g) for (s, g) in zip(states, golds)
if not s.is_final()]
if not s.is_final() and g is not None]
backprops = []
cdef float loss = 0.

View File

@@ -43,8 +43,6 @@ cdef class TransitionSystem:
cdef int initialize_state(self, StateC* state) nogil
cdef int finalize_state(self, StateC* state) nogil
cdef int preprocess_gold(self, GoldParse gold) except -1
cdef Transition lookup_transition(self, object name) except *
cdef Transition init_transition(self, int clas, int move, int label) except *

View File

@@ -70,7 +70,7 @@ cdef class TransitionSystem:
def finalize_doc(self, doc):
pass
cdef int preprocess_gold(self, GoldParse gold) except -1:
def preprocess_gold(self, GoldParse gold):
raise NotImplementedError
cdef Transition lookup_transition(self, object name) except *: