mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Improve efficiency of get_oracle_sequences
This commit is contained in:
parent
233945bfe0
commit
57e09747dc
|
@ -742,21 +742,14 @@ cdef class ArcEager(TransitionSystem):
|
|||
if n_gold < 1:
|
||||
raise ValueError
|
||||
|
||||
def get_oracle_sequence(self, Example example):
|
||||
cdef StateClass state
|
||||
cdef ArcEagerGold gold
|
||||
states, golds, n_steps = self.init_gold_batch([example])
|
||||
if not golds:
|
||||
return []
|
||||
|
||||
def get_oracle_sequence_from_state(self, StateClass state, ArcEagerGold gold, _debug=None):
|
||||
cdef int i
|
||||
cdef Pool mem = Pool()
|
||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
||||
assert self.n_moves > 0
|
||||
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
|
||||
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
|
||||
|
||||
state = states[0]
|
||||
gold = golds[0]
|
||||
history = []
|
||||
debug_log = []
|
||||
failed = False
|
||||
|
@ -772,18 +765,21 @@ cdef class ArcEager(TransitionSystem):
|
|||
history.append(i)
|
||||
s0 = state.S(0)
|
||||
b0 = state.B(0)
|
||||
debug_log.append(" ".join((
|
||||
self.get_class_name(i),
|
||||
"S0=", (example.x[s0].text if s0 >= 0 else "__"),
|
||||
"B0=", (example.x[b0].text if b0 >= 0 else "__"),
|
||||
"S0 head?", str(state.has_head(state.S(0))),
|
||||
)))
|
||||
if _debug:
|
||||
example = _debug
|
||||
debug_log.append(" ".join((
|
||||
self.get_class_name(i),
|
||||
"S0=", (example.x[s0].text if s0 >= 0 else "__"),
|
||||
"B0=", (example.x[b0].text if b0 >= 0 else "__"),
|
||||
"S0 head?", str(state.has_head(state.S(0))),
|
||||
)))
|
||||
action.do(state.c, action.label)
|
||||
break
|
||||
else:
|
||||
failed = False
|
||||
break
|
||||
if failed:
|
||||
example = _debug
|
||||
print("Actions")
|
||||
for i in range(self.n_moves):
|
||||
print(self.get_class_name(i))
|
||||
|
|
|
@ -63,7 +63,9 @@ cdef class Parser:
|
|||
self.model = model
|
||||
if self.moves.n_moves != 0:
|
||||
self.set_output(self.moves.n_moves)
|
||||
self.cfg = cfg
|
||||
self.cfg = dict(cfg)
|
||||
self.cfg.setdefault("update_with_oracle_cut_size", 100)
|
||||
self.cfg.setdefault("normalize_gradients_with_batch_size", True)
|
||||
self._multitasks = []
|
||||
for multitask in cfg.get("multitasks", []):
|
||||
self.add_multitask_objective(multitask)
|
||||
|
@ -272,13 +274,16 @@ cdef class Parser:
|
|||
# Prepare the stepwise model, and get the callback for finishing the batch
|
||||
model, backprop_tok2vec = self.model.begin_update(
|
||||
[eg.predicted for eg in examples])
|
||||
# Chop sequences into lengths of this many transitions, to make the
|
||||
# batch uniform length. We randomize this to overfit less.
|
||||
cut_gold = numpy.random.choice(range(20, 100))
|
||||
states, golds, max_steps = self._init_gold_batch(
|
||||
examples,
|
||||
max_length=cut_gold
|
||||
)
|
||||
if self.cfg["update_with_oracle_cut_size"] >= 1:
|
||||
# Chop sequences into lengths of this many transitions, to make the
|
||||
# batch uniform length. We randomize this to overfit less.
|
||||
cut_size = self.cfg["update_with_oracle_cut_size"]
|
||||
states, golds, max_steps = self._init_gold_batch(
|
||||
examples,
|
||||
max_length=numpy.random.choice(range(20, cut_size))
|
||||
)
|
||||
else:
|
||||
states, golds, max_steps = self.moves.init_gold_batch(examples)
|
||||
all_states = list(states)
|
||||
states_golds = zip(states, golds)
|
||||
for _ in range(max_steps):
|
||||
|
@ -384,7 +389,7 @@ cdef class Parser:
|
|||
cpu_log_loss(c_d_scores,
|
||||
costs, is_valid, &scores[i, 0], d_scores.shape[1])
|
||||
c_d_scores += d_scores.shape[1]
|
||||
if len(states):
|
||||
if len(states) and self.cfg["normalize_gradients_with_batch_size"]:
|
||||
d_scores /= len(states)
|
||||
if losses is not None:
|
||||
losses.setdefault(self.name, 0.)
|
||||
|
@ -516,7 +521,8 @@ cdef class Parser:
|
|||
states = []
|
||||
golds = []
|
||||
for eg, state, gold in kept:
|
||||
oracle_actions = self.moves.get_oracle_sequence(eg)
|
||||
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
||||
state, gold)
|
||||
start = 0
|
||||
while start < len(eg.predicted):
|
||||
state = state.copy()
|
||||
|
|
|
@ -60,20 +60,25 @@ cdef class TransitionSystem:
|
|||
states.append(state)
|
||||
offset += len(doc)
|
||||
return states
|
||||
|
||||
|
||||
def get_oracle_sequence(self, Example example, _debug=False):
|
||||
states, golds, _ = self.init_gold_batch([example])
|
||||
if not states:
|
||||
return []
|
||||
state = states[0]
|
||||
gold = golds[0]
|
||||
if _debug:
|
||||
return self.get_oracle_sequence_from_state(state, gold, _debug=example)
|
||||
else:
|
||||
return self.get_oracle_sequence_from_state(state, gold)
|
||||
|
||||
def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None):
|
||||
cdef Pool mem = Pool()
|
||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
||||
assert self.n_moves > 0
|
||||
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
|
||||
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
|
||||
|
||||
cdef StateClass state
|
||||
states, golds, n_steps = self.init_gold_batch([example])
|
||||
if not states:
|
||||
return []
|
||||
state = states[0]
|
||||
gold = golds[0]
|
||||
history = []
|
||||
debug_log = []
|
||||
while not state.is_final():
|
||||
|
@ -85,6 +90,7 @@ cdef class TransitionSystem:
|
|||
s0 = state.S(0)
|
||||
b0 = state.B(0)
|
||||
if _debug:
|
||||
example = _debug
|
||||
debug_log.append(" ".join((
|
||||
self.get_class_name(i),
|
||||
"S0=", (example.x[s0].text if s0 >= 0 else "__"),
|
||||
|
@ -95,6 +101,7 @@ cdef class TransitionSystem:
|
|||
break
|
||||
else:
|
||||
if _debug:
|
||||
example = _debug
|
||||
print("Actions")
|
||||
for i in range(self.n_moves):
|
||||
print(self.get_class_name(i))
|
||||
|
|
Loading…
Reference in New Issue
Block a user