Update transition system

This commit is contained in:
Matthew Honnibal 2020-06-21 01:12:05 +02:00
parent 318a046fb0
commit 7544c21f5b

View File

@ -1,4 +1,5 @@
# cython: infer_types=True # cython: infer_types=True
from __future__ import print_function
from cpython.ref cimport Py_INCREF from cpython.ref cimport Py_INCREF
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
@ -67,11 +68,13 @@ cdef class TransitionSystem:
costs = <float*>mem.alloc(self.n_moves, sizeof(float)) costs = <float*>mem.alloc(self.n_moves, sizeof(float))
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int)) is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
cdef StateClass state = StateClass(example.predicted, offset=0) cdef StateClass state
self.initialize_state(state.c) states, golds, n_steps = self.init_gold_batch([example])
state = states[0]
gold = golds[0]
history = [] history = []
while not state.is_final(): while not state.is_final():
self.set_costs(is_valid, costs, state, example) self.set_costs(is_valid, costs, state, gold)
for i in range(self.n_moves): for i in range(self.n_moves):
if is_valid[i] and costs[i] <= 0: if is_valid[i] and costs[i] <= 0:
action = self.c[i] action = self.c[i]