mirror of
https://github.com/explosion/spaCy.git
synced 2025-03-13 16:05:50 +03:00
Return kept examples from init_gold_batch
This commit is contained in:
parent
b3625dc697
commit
4925c0be34
|
@ -577,15 +577,16 @@ cdef class ArcEager(TransitionSystem):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def init_gold_batch(self, examples):
|
def init_gold_batch(self, examples):
|
||||||
examples = [eg for eg in examples if self.has_gold(eg)]
|
|
||||||
states = self.init_batch([eg.predicted for eg in examples])
|
states = self.init_batch([eg.predicted for eg in examples])
|
||||||
keeps = [i for i, s in enumerate(states) if not s.is_final()]
|
keeps = [i for i, (eg, s) in enumerate(zip(examples, states))
|
||||||
|
if self.has_gold(eg) and not s.is_final()]
|
||||||
golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps]
|
golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps]
|
||||||
|
examples = [examples[i] for i in keeps]
|
||||||
states = [states[i] for i in keeps]
|
states = [states[i] for i in keeps]
|
||||||
for gold in golds:
|
for gold in golds:
|
||||||
self._replace_unseen_labels(gold)
|
self._replace_unseen_labels(gold)
|
||||||
n_steps = sum([len(s.queue) * 4 for s in states])
|
n_steps = sum([len(s.queue) * 4 for s in states])
|
||||||
return states, golds, n_steps
|
return states, golds, examples, n_steps
|
||||||
|
|
||||||
def _replace_unseen_labels(self, ArcEagerGold gold):
|
def _replace_unseen_labels(self, ArcEagerGold gold):
|
||||||
backoff_label = self.strings["dep"]
|
backoff_label = self.strings["dep"]
|
||||||
|
|
|
@ -130,13 +130,13 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
return MOVE_NAMES[move] + '-' + self.strings[label]
|
return MOVE_NAMES[move] + '-' + self.strings[label]
|
||||||
|
|
||||||
def init_gold_batch(self, examples):
|
def init_gold_batch(self, examples):
|
||||||
examples = [eg for eg in examples if self.has_gold(eg)]
|
|
||||||
states = self.init_batch([eg.predicted for eg in examples])
|
states = self.init_batch([eg.predicted for eg in examples])
|
||||||
keeps = [i for i, s in enumerate(states) if not s.is_final()]
|
keeps = [i for i, (s, eg) in enumerate(zip(states, examples))
|
||||||
|
if not s.is_final() and self.has_gold(eg)]
|
||||||
golds = [BiluoGold(self, states[i], examples[i]) for i in keeps]
|
golds = [BiluoGold(self, states[i], examples[i]) for i in keeps]
|
||||||
states = [states[i] for i in keeps]
|
states = [states[i] for i in keeps]
|
||||||
n_steps = sum([len(s.queue) for s in states])
|
n_steps = sum([len(s.queue) for s in states])
|
||||||
return states, golds, n_steps
|
return states, golds, examples, n_steps
|
||||||
|
|
||||||
cdef Transition lookup_transition(self, object name) except *:
|
cdef Transition lookup_transition(self, object name) except *:
|
||||||
cdef attr_t label
|
cdef attr_t label
|
||||||
|
@ -262,11 +262,11 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
n_gold = 0
|
n_gold = 0
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
||||||
is_valid[i] = True
|
is_valid[i] = 1
|
||||||
costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
|
costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
|
||||||
n_gold += costs[i] <= 0
|
n_gold += costs[i] <= 0
|
||||||
else:
|
else:
|
||||||
is_valid[i] = False
|
is_valid[i] = 0
|
||||||
costs[i] = 9000
|
costs[i] = 9000
|
||||||
if n_gold < 1:
|
if n_gold < 1:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
Loading…
Reference in New Issue
Block a user