Return kept examples from init_gold_batch

This commit is contained in:
Matthew Honnibal 2020-06-25 03:42:41 +02:00
parent b3625dc697
commit 4925c0be34
2 changed files with 9 additions and 8 deletions

View File

@ -577,15 +577,16 @@ cdef class ArcEager(TransitionSystem):
raise NotImplementedError raise NotImplementedError
def init_gold_batch(self, examples): def init_gold_batch(self, examples):
examples = [eg for eg in examples if self.has_gold(eg)]
states = self.init_batch([eg.predicted for eg in examples]) states = self.init_batch([eg.predicted for eg in examples])
keeps = [i for i, s in enumerate(states) if not s.is_final()] keeps = [i for i, (eg, s) in enumerate(zip(examples, states))
if self.has_gold(eg) and not s.is_final()]
golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps] golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps]
examples = [examples[i] for i in keeps]
states = [states[i] for i in keeps] states = [states[i] for i in keeps]
for gold in golds: for gold in golds:
self._replace_unseen_labels(gold) self._replace_unseen_labels(gold)
n_steps = sum([len(s.queue) * 4 for s in states]) n_steps = sum([len(s.queue) * 4 for s in states])
return states, golds, n_steps return states, golds, examples, n_steps
def _replace_unseen_labels(self, ArcEagerGold gold): def _replace_unseen_labels(self, ArcEagerGold gold):
backoff_label = self.strings["dep"] backoff_label = self.strings["dep"]

View File

@ -130,13 +130,13 @@ cdef class BiluoPushDown(TransitionSystem):
return MOVE_NAMES[move] + '-' + self.strings[label] return MOVE_NAMES[move] + '-' + self.strings[label]
def init_gold_batch(self, examples): def init_gold_batch(self, examples):
examples = [eg for eg in examples if self.has_gold(eg)]
states = self.init_batch([eg.predicted for eg in examples]) states = self.init_batch([eg.predicted for eg in examples])
keeps = [i for i, s in enumerate(states) if not s.is_final()] keeps = [i for i, (s, eg) in enumerate(zip(states, examples))
if not s.is_final() and self.has_gold(eg)]
golds = [BiluoGold(self, states[i], examples[i]) for i in keeps] golds = [BiluoGold(self, states[i], examples[i]) for i in keeps]
states = [states[i] for i in keeps] states = [states[i] for i in keeps]
n_steps = sum([len(s.queue) for s in states]) n_steps = sum([len(s.queue) for s in states])
return states, golds, n_steps return states, golds, examples, n_steps
cdef Transition lookup_transition(self, object name) except *: cdef Transition lookup_transition(self, object name) except *:
cdef attr_t label cdef attr_t label
@ -262,11 +262,11 @@ cdef class BiluoPushDown(TransitionSystem):
n_gold = 0 n_gold = 0
for i in range(self.n_moves): for i in range(self.n_moves):
if self.c[i].is_valid(stcls.c, self.c[i].label): if self.c[i].is_valid(stcls.c, self.c[i].label):
is_valid[i] = True is_valid[i] = 1
costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label) costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
n_gold += costs[i] <= 0 n_gold += costs[i] <= 0
else: else:
is_valid[i] = False is_valid[i] = 0
costs[i] = 9000 costs[i] = 9000
if n_gold < 1: if n_gold < 1:
raise ValueError raise ValueError