mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-30 20:06:30 +03:00
61 lines
1.9 KiB
Cython
61 lines
1.9 KiB
Cython
|
from __future__ import unicode_literals
|
||
|
|
||
|
from ._state cimport init_state
|
||
|
from ._state cimport entity_is_open
|
||
|
from .bilou_moves cimport fill_moves
|
||
|
from .bilou_moves cimport transition
|
||
|
from .bilou_moves cimport set_accept_if_valid, set_accept_if_oracle
|
||
|
from .bilou_moves import get_n_moves
|
||
|
from .bilou_moves import ACTION_NAMES
|
||
|
|
||
|
|
||
|
cdef class PyState:
|
||
|
def __init__(self, tag_names, n_tokens):
|
||
|
self.mem = Pool()
|
||
|
self.tag_names = tag_names
|
||
|
self.n_classes = len(tag_names)
|
||
|
assert self.n_classes != 0
|
||
|
self._moves = <Move*>self.mem.alloc(self.n_classes, sizeof(Move))
|
||
|
fill_moves(self._moves, tag_names)
|
||
|
self._s = init_state(self.mem, n_tokens)
|
||
|
self._golds = <Move*>self.mem.alloc(n_tokens, sizeof(Move))
|
||
|
|
||
|
cdef Move* _get_move(self, unicode move_name) except NULL:
|
||
|
return &self._moves[self.tag_names.index(move_name)]
|
||
|
|
||
|
def set_golds(self, list gold_names):
|
||
|
cdef Move* m
|
||
|
for i, name in enumerate(gold_names):
|
||
|
m = self._get_move(name)
|
||
|
self._golds[i] = m[0]
|
||
|
|
||
|
def transition(self, unicode move_name):
|
||
|
cdef Move* m = self._get_move(move_name)
|
||
|
transition(self._s, m)
|
||
|
|
||
|
def is_valid(self, unicode move_name):
|
||
|
cdef Move* m = self._get_move(move_name)
|
||
|
set_accept_if_valid(self._moves, self.n_classes, self._s)
|
||
|
return m.accept
|
||
|
|
||
|
def is_gold(self, unicode move_name):
|
||
|
cdef Move* m = self._get_move(move_name)
|
||
|
set_accept_if_oracle(self._moves, self._golds, self.n_classes, self._s)
|
||
|
return m.accept
|
||
|
|
||
|
property ent:
|
||
|
def __get__(self):
|
||
|
return self._s.curr
|
||
|
|
||
|
property n_ents:
|
||
|
def __get__(self):
|
||
|
return self._s.j
|
||
|
|
||
|
property i:
|
||
|
def __get__(self):
|
||
|
return self._s.i
|
||
|
|
||
|
property open_entity:
|
||
|
def __get__(self):
|
||
|
return entity_is_open(self._s)
|