spaCy/spacy/pipeline/_parser_internals/stateclass.pyx
2023-07-19 16:37:31 +02:00

185 lines
4.7 KiB
Cython

# cython: infer_types=True
from libcpp.vector cimport vector
from ...tokens.doc cimport Doc
from ._state cimport ArcC
cdef class StateClass:
def __init__(self, Doc doc=None, int offset=0):
self._borrowed = 0
if doc is not None:
self.c = new StateC(doc.c, doc.length)
self.c.offset = offset
self.doc = doc
else:
self.doc = None
def __dealloc__(self):
if self._borrowed != 1:
del self.c
@property
def history(self):
return list(self.c.history)
@property
def stack(self):
return [self.S(i) for i in range(self.c.stack_depth())]
@property
def queue(self):
return [self.B(i) for i in range(self.c.buffer_length())]
@property
def token_vector_lenth(self):
return self.doc.tensor.shape[1]
@property
def arcs(self):
cdef vector[ArcC] arcs
self.c.get_arcs(&arcs)
return list(arcs)
# py_arcs = []
# for arc in arcs:
# if arc.head != -1 and arc.child != -1:
# py_arcs.append((arc.head, arc.child, arc.label))
# return arcs
def add_arc(self, int head, int child, int label):
self.c.add_arc(head, child, label)
def del_arc(self, int head, int child):
self.c.del_arc(head, child)
def H(self, int child):
return self.c.H(child)
def L(self, int head, int idx):
return self.c.L(head, idx)
def R(self, int head, int idx):
return self.c.R(head, idx)
@property
def _b_i(self):
return self.c._b_i
@property
def length(self):
return self.c.length
def is_final(self):
return self.c.is_final()
def copy(self):
cdef StateClass new_state = StateClass(doc=self.doc, offset=self.c.offset)
new_state.c.clone(self.c)
return new_state
def print_state(self):
words = [token.text for token in self.doc]
words = list(words) + ['_']
bools = ["F", "T"]
sent_starts = [bools[self.c.is_sent_start(i)] for i in range(len(self.doc))]
shifted = [1 if self.c.is_unshiftable(i) else 0 for i in range(self.c.length)]
shifted.append("")
sent_starts.append("")
top = f"{self.S(0)}{words[self.S(0)]}_{words[self.H(self.S(0))]}_{shifted[self.S(0)]}"
second = f"{self.S(1)}{words[self.S(1)]}_{words[self.H(self.S(1))]}_{shifted[self.S(1)]}"
third = f"{self.S(2)}{words[self.S(2)]}_{words[self.H(self.S(2))]}_{shifted[self.S(2)]}"
n0 = f"{self.B(0)}{words[self.B(0)]}_{sent_starts[self.B(0)]}_{shifted[self.B(0)]}"
n1 = f"{self.B(1)}{words[self.B(1)]}_{sent_starts[self.B(1)]}_{shifted[self.B(1)]}"
return ' '.join((str(self.stack_depth()), str(self.buffer_length()), third, second, top, '|', n0, n1))
def S(self, int i):
return self.c.S(i)
def B(self, int i):
return self.c.B(i)
def H(self, int i):
return self.c.H(i)
def E(self, int i):
return self.c.E(i)
def L(self, int i, int idx):
return self.c.L(i, idx)
def R(self, int i, int idx):
return self.c.R(i, idx)
def S_(self, int i):
return self.doc[self.c.S(i)]
def B_(self, int i):
return self.doc[self.c.B(i)]
def H_(self, int i):
return self.doc[self.c.H(i)]
def E_(self, int i):
return self.doc[self.c.E(i)]
def L_(self, int i, int idx):
return self.doc[self.c.L(i, idx)]
def R_(self, int i, int idx):
return self.doc[self.c.R(i, idx)]
def empty(self):
return self.c.empty()
def eol(self):
return self.c.eol()
def at_break(self):
return False
# return self.c.at_break()
def has_head(self, int i):
return self.c.has_head(i)
def n_L(self, int i):
return self.c.n_L(i)
def n_R(self, int i):
return self.c.n_R(i)
def entity_is_open(self):
return self.c.entity_is_open()
def stack_depth(self):
return self.c.stack_depth()
def buffer_length(self):
return self.c.buffer_length()
def push(self):
self.c.push()
def pop(self):
self.c.pop()
def unshift(self):
self.c.unshift()
def add_arc(self, int head, int child, attr_t label):
self.c.add_arc(head, child, label)
def del_arc(self, int head, int child):
self.c.del_arc(head, child)
def open_ent(self, attr_t label):
self.c.open_ent(label)
def close_ent(self):
self.c.close_ent()
def clone(self, StateClass src):
self.c.clone(src.c)
def set_context_tokens(self, int[:, :] output, int row, int n_feats):
self.c.set_context_tokens(&output[row, 0], n_feats)