* Implement validation and oracle on pystate, for testing

This commit is contained in:
Matthew Honnibal 2014-11-10 22:15:18 +11:00
parent 3709ed9d6d
commit 82247169f2
2 changed files with 15 additions and 6 deletions

View File

@ -11,6 +11,7 @@ cdef class PyState:
cdef readonly dict moves_by_name
cdef Move* _moves
cdef Move* _golds
cdef State* _s
cdef Move* _get_move(self, unicode move_name) except NULL

View File

@ -4,7 +4,7 @@ from ._state cimport init_state
from ._state cimport entity_is_open
from .moves cimport fill_moves
from .moves cimport transition
from .moves cimport set_accept_if_valid
from .moves cimport set_accept_if_valid, set_accept_if_oracle
from .moves import get_n_moves
from .moves import ACTION_NAMES
@ -27,10 +27,18 @@ cdef class PyState:
else:
tag_name = tag_names[m.label]
self.moves_by_name['%s-%s' % (action_name, tag_name)] = i
# TODO
self._golds = <Move*>self.mem.alloc(n_tokens, sizeof(Move))
cdef Move* _get_move(self, unicode move_name) except NULL:
return &self._moves[self.moves_by_name[move_name]]
def set_golds(self, list gold_names):
cdef Move* m
for i, name in enumerate(gold_names):
m = self._get_move(name)
self._golds[i] = m[0]
def transition(self, unicode move_name):
cdef Move* m = self._get_move(move_name)
transition(self._s, m)
@ -41,15 +49,17 @@ cdef class PyState:
return m.accept
def is_gold(self, unicode move_name):
pass
set_accept_if_oracle(self._moves, self._golds, self.n_classes, self._s)
cdef Move* m = self._get_move(move_name)
return m.accept
property ent:
def __get__(self):
return self._s.ents[self._s.j]
return self._s.curr
property n_ents:
def __get__(self):
return self._s.j + 1
return self._s.j
property i:
def __get__(self):
@ -58,5 +68,3 @@ cdef class PyState:
property open_entity:
def __get__(self):
return entity_is_open(self._s)