Clean up debugging

This commit is contained in:
Matthew Honnibal 2020-06-22 15:34:34 +02:00
parent b250f6b62f
commit e92be79ffc

View File

@ -61,7 +61,7 @@ cdef class TransitionSystem:
offset += len(doc) offset += len(doc)
return states return states
def get_oracle_sequence(self, Example example): def get_oracle_sequence(self, Example example, _debug=False):
cdef Pool mem = Pool() cdef Pool mem = Pool()
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
assert self.n_moves > 0 assert self.n_moves > 0
@ -70,6 +70,8 @@ cdef class TransitionSystem:
cdef StateClass state cdef StateClass state
states, golds, n_steps = self.init_gold_batch([example]) states, golds, n_steps = self.init_gold_batch([example])
if not states:
return []
state = states[0] state = states[0]
gold = golds[0] gold = golds[0]
history = [] history = []
@ -82,6 +84,7 @@ cdef class TransitionSystem:
history.append(i) history.append(i)
s0 = state.S(0) s0 = state.S(0)
b0 = state.B(0) b0 = state.B(0)
if _debug:
debug_log.append(" ".join(( debug_log.append(" ".join((
self.get_class_name(i), self.get_class_name(i),
"S0=", (example.x[s0].text if s0 >= 0 else "__"), "S0=", (example.x[s0].text if s0 >= 0 else "__"),
@ -91,6 +94,7 @@ cdef class TransitionSystem:
action.do(state.c, action.label) action.do(state.c, action.label)
break break
else: else:
if _debug:
print("Actions") print("Actions")
for i in range(self.n_moves): for i in range(self.n_moves):
print(self.get_class_name(i)) print(self.get_class_name(i))