mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Exclude states with no matching gold annotations from parsing
This commit is contained in:
parent
83ffd16474
commit
e2136232f9
|
@ -350,7 +350,9 @@ cdef class ArcEager(TransitionSystem):
|
|||
def __get__(self):
|
||||
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
|
||||
|
||||
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
||||
def preprocess_gold(self, GoldParse gold):
|
||||
if all([h is None for h in gold.heads]):
|
||||
return None
|
||||
for i in range(gold.length):
|
||||
if gold.heads[i] is None: # Missing values
|
||||
gold.c.heads[i] = i
|
||||
|
@ -361,6 +363,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
label = 'ROOT'
|
||||
gold.c.heads[i] = gold.heads[i]
|
||||
gold.c.labels[i] = self.strings[label]
|
||||
return gold
|
||||
|
||||
cdef Transition lookup_transition(self, object name) except *:
|
||||
if '-' in name:
|
||||
|
|
|
@ -95,9 +95,12 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
else:
|
||||
return MOVE_NAMES[move] + '-' + self.strings[label]
|
||||
|
||||
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
||||
def preprocess_gold(self, GoldParse gold):
|
||||
if all([tag == '-' for tag in gold.ner]):
|
||||
return None
|
||||
for i in range(gold.length):
|
||||
gold.c.ner[i] = self.lookup_transition(gold.ner[i])
|
||||
return gold
|
||||
|
||||
cdef Transition lookup_transition(self, object name) except *:
|
||||
if name == '-' or name == None:
|
||||
|
|
|
@ -318,15 +318,14 @@ cdef class Parser:
|
|||
golds = [golds]
|
||||
|
||||
cuda_stream = get_cuda_stream()
|
||||
for gold in golds:
|
||||
self.moves.preprocess_gold(gold)
|
||||
golds = [self.moves.preprocess_gold(g) for g in golds]
|
||||
|
||||
states = self.moves.init_batch(docs)
|
||||
state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream,
|
||||
drop)
|
||||
|
||||
todo = [(s, g) for (s, g) in zip(states, golds)
|
||||
if not s.is_final()]
|
||||
if not s.is_final() and g is not None]
|
||||
|
||||
backprops = []
|
||||
cdef float loss = 0.
|
||||
|
|
|
@ -43,8 +43,6 @@ cdef class TransitionSystem:
|
|||
cdef int initialize_state(self, StateC* state) nogil
|
||||
cdef int finalize_state(self, StateC* state) nogil
|
||||
|
||||
cdef int preprocess_gold(self, GoldParse gold) except -1
|
||||
|
||||
cdef Transition lookup_transition(self, object name) except *
|
||||
|
||||
cdef Transition init_transition(self, int clas, int move, int label) except *
|
||||
|
|
|
@ -70,7 +70,7 @@ cdef class TransitionSystem:
|
|||
def finalize_doc(self, doc):
|
||||
pass
|
||||
|
||||
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
||||
def preprocess_gold(self, GoldParse gold):
|
||||
raise NotImplementedError
|
||||
|
||||
cdef Transition lookup_transition(self, object name) except *:
|
||||
|
|
Loading…
Reference in New Issue
Block a user