mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 09:00:36 +03:00
Update transition system
This commit is contained in:
parent
318a046fb0
commit
7544c21f5b
|
@ -1,4 +1,5 @@
|
||||||
# cython: infer_types=True
|
# cython: infer_types=True
|
||||||
|
from __future__ import print_function
|
||||||
from cpython.ref cimport Py_INCREF
|
from cpython.ref cimport Py_INCREF
|
||||||
from cymem.cymem cimport Pool
|
from cymem.cymem cimport Pool
|
||||||
|
|
||||||
|
@ -67,11 +68,13 @@ cdef class TransitionSystem:
|
||||||
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
|
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
|
||||||
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
|
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
|
||||||
|
|
||||||
cdef StateClass state = StateClass(example.predicted, offset=0)
|
cdef StateClass state
|
||||||
self.initialize_state(state.c)
|
states, golds, n_steps = self.init_gold_batch([example])
|
||||||
|
state = states[0]
|
||||||
|
gold = golds[0]
|
||||||
history = []
|
history = []
|
||||||
while not state.is_final():
|
while not state.is_final():
|
||||||
self.set_costs(is_valid, costs, state, example)
|
self.set_costs(is_valid, costs, state, gold)
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
if is_valid[i] and costs[i] <= 0:
|
if is_valid[i] and costs[i] <= 0:
|
||||||
action = self.c[i]
|
action = self.c[i]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user