mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-03 12:43:15 +03:00
Improve efficiency of get_oracle_sequences
This commit is contained in:
parent
233945bfe0
commit
57e09747dc
|
@ -742,21 +742,14 @@ cdef class ArcEager(TransitionSystem):
|
||||||
if n_gold < 1:
|
if n_gold < 1:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
def get_oracle_sequence(self, Example example):
|
def get_oracle_sequence_from_state(self, StateClass state, ArcEagerGold gold, _debug=None):
|
||||||
cdef StateClass state
|
cdef int i
|
||||||
cdef ArcEagerGold gold
|
|
||||||
states, golds, n_steps = self.init_gold_batch([example])
|
|
||||||
if not golds:
|
|
||||||
return []
|
|
||||||
|
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
||||||
assert self.n_moves > 0
|
assert self.n_moves > 0
|
||||||
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
|
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
|
||||||
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
|
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
|
||||||
|
|
||||||
state = states[0]
|
|
||||||
gold = golds[0]
|
|
||||||
history = []
|
history = []
|
||||||
debug_log = []
|
debug_log = []
|
||||||
failed = False
|
failed = False
|
||||||
|
@ -772,6 +765,8 @@ cdef class ArcEager(TransitionSystem):
|
||||||
history.append(i)
|
history.append(i)
|
||||||
s0 = state.S(0)
|
s0 = state.S(0)
|
||||||
b0 = state.B(0)
|
b0 = state.B(0)
|
||||||
|
if _debug:
|
||||||
|
example = _debug
|
||||||
debug_log.append(" ".join((
|
debug_log.append(" ".join((
|
||||||
self.get_class_name(i),
|
self.get_class_name(i),
|
||||||
"S0=", (example.x[s0].text if s0 >= 0 else "__"),
|
"S0=", (example.x[s0].text if s0 >= 0 else "__"),
|
||||||
|
@ -784,6 +779,7 @@ cdef class ArcEager(TransitionSystem):
|
||||||
failed = False
|
failed = False
|
||||||
break
|
break
|
||||||
if failed:
|
if failed:
|
||||||
|
example = _debug
|
||||||
print("Actions")
|
print("Actions")
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
print(self.get_class_name(i))
|
print(self.get_class_name(i))
|
||||||
|
|
|
@ -63,7 +63,9 @@ cdef class Parser:
|
||||||
self.model = model
|
self.model = model
|
||||||
if self.moves.n_moves != 0:
|
if self.moves.n_moves != 0:
|
||||||
self.set_output(self.moves.n_moves)
|
self.set_output(self.moves.n_moves)
|
||||||
self.cfg = cfg
|
self.cfg = dict(cfg)
|
||||||
|
self.cfg.setdefault("update_with_oracle_cut_size", 100)
|
||||||
|
self.cfg.setdefault("normalize_gradients_with_batch_size", True)
|
||||||
self._multitasks = []
|
self._multitasks = []
|
||||||
for multitask in cfg.get("multitasks", []):
|
for multitask in cfg.get("multitasks", []):
|
||||||
self.add_multitask_objective(multitask)
|
self.add_multitask_objective(multitask)
|
||||||
|
@ -272,13 +274,16 @@ cdef class Parser:
|
||||||
# Prepare the stepwise model, and get the callback for finishing the batch
|
# Prepare the stepwise model, and get the callback for finishing the batch
|
||||||
model, backprop_tok2vec = self.model.begin_update(
|
model, backprop_tok2vec = self.model.begin_update(
|
||||||
[eg.predicted for eg in examples])
|
[eg.predicted for eg in examples])
|
||||||
|
if self.cfg["update_with_oracle_cut_size"] >= 1:
|
||||||
# Chop sequences into lengths of this many transitions, to make the
|
# Chop sequences into lengths of this many transitions, to make the
|
||||||
# batch uniform length. We randomize this to overfit less.
|
# batch uniform length. We randomize this to overfit less.
|
||||||
cut_gold = numpy.random.choice(range(20, 100))
|
cut_size = self.cfg["update_with_oracle_cut_size"]
|
||||||
states, golds, max_steps = self._init_gold_batch(
|
states, golds, max_steps = self._init_gold_batch(
|
||||||
examples,
|
examples,
|
||||||
max_length=cut_gold
|
max_length=numpy.random.choice(range(20, cut_size))
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
states, golds, max_steps = self.moves.init_gold_batch(examples)
|
||||||
all_states = list(states)
|
all_states = list(states)
|
||||||
states_golds = zip(states, golds)
|
states_golds = zip(states, golds)
|
||||||
for _ in range(max_steps):
|
for _ in range(max_steps):
|
||||||
|
@ -384,7 +389,7 @@ cdef class Parser:
|
||||||
cpu_log_loss(c_d_scores,
|
cpu_log_loss(c_d_scores,
|
||||||
costs, is_valid, &scores[i, 0], d_scores.shape[1])
|
costs, is_valid, &scores[i, 0], d_scores.shape[1])
|
||||||
c_d_scores += d_scores.shape[1]
|
c_d_scores += d_scores.shape[1]
|
||||||
if len(states):
|
if len(states) and self.cfg["normalize_gradients_with_batch_size"]:
|
||||||
d_scores /= len(states)
|
d_scores /= len(states)
|
||||||
if losses is not None:
|
if losses is not None:
|
||||||
losses.setdefault(self.name, 0.)
|
losses.setdefault(self.name, 0.)
|
||||||
|
@ -516,7 +521,8 @@ cdef class Parser:
|
||||||
states = []
|
states = []
|
||||||
golds = []
|
golds = []
|
||||||
for eg, state, gold in kept:
|
for eg, state, gold in kept:
|
||||||
oracle_actions = self.moves.get_oracle_sequence(eg)
|
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
||||||
|
state, gold)
|
||||||
start = 0
|
start = 0
|
||||||
while start < len(eg.predicted):
|
while start < len(eg.predicted):
|
||||||
state = state.copy()
|
state = state.copy()
|
||||||
|
|
|
@ -62,18 +62,23 @@ cdef class TransitionSystem:
|
||||||
return states
|
return states
|
||||||
|
|
||||||
def get_oracle_sequence(self, Example example, _debug=False):
|
def get_oracle_sequence(self, Example example, _debug=False):
|
||||||
|
states, golds, _ = self.init_gold_batch([example])
|
||||||
|
if not states:
|
||||||
|
return []
|
||||||
|
state = states[0]
|
||||||
|
gold = golds[0]
|
||||||
|
if _debug:
|
||||||
|
return self.get_oracle_sequence_from_state(state, gold, _debug=example)
|
||||||
|
else:
|
||||||
|
return self.get_oracle_sequence_from_state(state, gold)
|
||||||
|
|
||||||
|
def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None):
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
||||||
assert self.n_moves > 0
|
assert self.n_moves > 0
|
||||||
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
|
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
|
||||||
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
|
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
|
||||||
|
|
||||||
cdef StateClass state
|
|
||||||
states, golds, n_steps = self.init_gold_batch([example])
|
|
||||||
if not states:
|
|
||||||
return []
|
|
||||||
state = states[0]
|
|
||||||
gold = golds[0]
|
|
||||||
history = []
|
history = []
|
||||||
debug_log = []
|
debug_log = []
|
||||||
while not state.is_final():
|
while not state.is_final():
|
||||||
|
@ -85,6 +90,7 @@ cdef class TransitionSystem:
|
||||||
s0 = state.S(0)
|
s0 = state.S(0)
|
||||||
b0 = state.B(0)
|
b0 = state.B(0)
|
||||||
if _debug:
|
if _debug:
|
||||||
|
example = _debug
|
||||||
debug_log.append(" ".join((
|
debug_log.append(" ".join((
|
||||||
self.get_class_name(i),
|
self.get_class_name(i),
|
||||||
"S0=", (example.x[s0].text if s0 >= 0 else "__"),
|
"S0=", (example.x[s0].text if s0 >= 0 else "__"),
|
||||||
|
@ -95,6 +101,7 @@ cdef class TransitionSystem:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
if _debug:
|
if _debug:
|
||||||
|
example = _debug
|
||||||
print("Actions")
|
print("Actions")
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
print(self.get_class_name(i))
|
print(self.get_class_name(i))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user