# cython: infer_types=True
import numpy
from libcpp.vector cimport vector
from ._state cimport ArcC

from ...tokens.doc cimport Doc


cdef class StateClass:
    def __init__(self, Doc doc=None, int offset=0):
        self._borrowed = 0
        if doc is not None:
            self.c = new StateC(doc.c, doc.length)
            self.c.offset = offset
            self.doc = doc
        else:
            self.doc = None

    def __dealloc__(self):
        if self._borrowed != 1:
            del self.c

    @property
    def stack(self):
        return [self.S(i) for i in range(self.c.stack_depth())]

    @property
    def queue(self):
        return [self.B(i) for i in range(self.c.buffer_length())]

    @property
    def token_vector_lenth(self):
        return self.doc.tensor.shape[1]

    @property
    def arcs(self):
        cdef vector[ArcC] arcs
        self.c.get_arcs(&arcs)
        return list(arcs)
        #py_arcs = []
        #for arc in arcs:
        #    if arc.head != -1 and arc.child != -1:
        #        py_arcs.append((arc.head, arc.child, arc.label))
        #return arcs

    def add_arc(self, int head, int child, int label):
        self.c.add_arc(head, child, label)

    def del_arc(self, int head, int child):
        self.c.del_arc(head, child)

    def H(self, int child):
        return self.c.H(child)
    
    def L(self, int head, int idx):
        return self.c.L(head, idx)
    
    def R(self, int head, int idx):
        return self.c.R(head, idx)

    @property
    def _b_i(self):
        return self.c._b_i

    @property
    def length(self):
        return self.c.length

    def is_final(self):
        return self.c.is_final()

    def copy(self):
        cdef StateClass new_state = StateClass(doc=self.doc, offset=self.c.offset)
        new_state.c.clone(self.c)
        return new_state

    def print_state(self):
        words = [token.text for token in self.doc]
        words = list(words) + ['_']
        bools = ["F", "T"]
        sent_starts = [bools[self.c.is_sent_start(i)] for i in range(len(self.doc))]
        shifted = [1 if self.c.is_unshiftable(i) else 0 for i in range(self.c.length)]
        shifted.append("")
        sent_starts.append("")
        top = f"{self.S(0)}{words[self.S(0)]}_{words[self.H(self.S(0))]}_{shifted[self.S(0)]}"
        second = f"{self.S(1)}{words[self.S(1)]}_{words[self.H(self.S(1))]}_{shifted[self.S(1)]}"
        third = f"{self.S(2)}{words[self.S(2)]}_{words[self.H(self.S(2))]}_{shifted[self.S(2)]}"
        n0 = f"{self.B(0)}{words[self.B(0)]}_{sent_starts[self.B(0)]}_{shifted[self.B(0)]}"
        n1 = f"{self.B(1)}{words[self.B(1)]}_{sent_starts[self.B(1)]}_{shifted[self.B(1)]}"
        return ' '.join((str(self.stack_depth()), str(self.buffer_length()), third, second, top, '|', n0, n1))

    def S(self, int i):
        return self.c.S(i)

    def B(self, int i):
        return self.c.B(i)

    def H(self, int i):
        return self.c.H(i)
    
    def E(self, int i):
        return self.c.E(i)

    def L(self, int i, int idx):
        return self.c.L(i, idx)

    def R(self, int i, int idx):
        return self.c.R(i, idx)

    def S_(self, int i):
        return self.doc[self.c.S(i)]

    def B_(self, int i):
        return self.doc[self.c.B(i)]

    def H_(self, int i):
        return self.doc[self.c.H(i)]
    
    def E_(self, int i):
        return self.doc[self.c.E(i)]

    def L_(self, int i, int idx):
        return self.doc[self.c.L(i, idx)]

    def R_(self, int i, int idx):
        return self.doc[self.c.R(i, idx)]
 
    def empty(self):
        return self.c.empty()

    def eol(self):
        return self.c.eol()

    def at_break(self):
        return False
        #return self.c.at_break()

    def has_head(self, int i):
        return self.c.has_head(i)

    def  n_L(self, int i):
        return self.c.n_L(i)

    def n_R(self, int i):
        return self.c.n_R(i)

    def entity_is_open(self):
        return self.c.entity_is_open()

    def stack_depth(self):
        return self.c.stack_depth()

    def buffer_length(self):
        return self.c.buffer_length()

    def push(self):
        self.c.push()

    def pop(self):
        self.c.pop()

    def unshift(self):
        self.c.unshift()

    def add_arc(self, int head, int child, attr_t label):
        self.c.add_arc(head, child, label)

    def del_arc(self, int head, int child):
        self.c.del_arc(head, child)

    def open_ent(self, attr_t label):
        self.c.open_ent(label)

    def close_ent(self):
        self.c.close_ent()

    def clone(self, StateClass src):
        self.c.clone(src.c)