spaCy/spacy/pipeline/_parser_internals/arc_eager.pyx
2021-01-12 17:28:41 +01:00

861 lines
28 KiB
Cython

# cython: profile=True, cdivision=True, infer_types=True
from cymem.cymem cimport Pool, Address
from libc.stdint cimport int32_t
from libcpp.vector cimport vector
from collections import defaultdict, Counter
from ...typedefs cimport hash_t, attr_t
from ...strings cimport hash_string
from ...structs cimport TokenC
from ...tokens.doc cimport Doc, set_children_from_heads
from ...tokens.token import MISSING_DEP_
from ...training.example cimport Example
from .stateclass cimport StateClass
from ._state cimport StateC, ArcC
from ...errors import Errors
from thinc.extra.search cimport Beam
cdef weight_t MIN_SCORE = -90000
cdef attr_t SUBTOK_LABEL = hash_string(u'subtok')
DEF NON_MONOTONIC = True
cdef enum:
SHIFT
REDUCE
LEFT
RIGHT
BREAK
N_MOVES
MOVE_NAMES = [None] * N_MOVES
MOVE_NAMES[SHIFT] = 'S'
MOVE_NAMES[REDUCE] = 'D'
MOVE_NAMES[LEFT] = 'L'
MOVE_NAMES[RIGHT] = 'R'
MOVE_NAMES[BREAK] = 'B'
cdef enum:
HEAD_IN_STACK = 0
HEAD_IN_BUFFER
HEAD_UNKNOWN
IS_SENT_START
SENT_START_UNKNOWN
cdef struct GoldParseStateC:
char* state_bits
int32_t* n_kids_in_buffer
int32_t* n_kids_in_stack
int32_t* heads
attr_t* labels
int32_t** kids
int32_t* n_kids
int32_t length
int32_t stride
weight_t push_cost
weight_t pop_cost
cdef GoldParseStateC create_gold_state(Pool mem, const StateC* state,
heads, labels, sent_starts) except *:
cdef GoldParseStateC gs
gs.length = len(heads)
gs.stride = 1
assert gs.length > 0
gs.labels = <attr_t*>mem.alloc(gs.length, sizeof(gs.labels[0]))
gs.heads = <int32_t*>mem.alloc(gs.length, sizeof(gs.heads[0]))
gs.n_kids = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids[0]))
gs.state_bits = <char*>mem.alloc(gs.length, sizeof(gs.state_bits[0]))
gs.n_kids_in_buffer = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_buffer[0]))
gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0]))
for i, is_sent_start in enumerate(sent_starts):
if is_sent_start == True:
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
IS_SENT_START,
1
)
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
SENT_START_UNKNOWN,
0
)
elif is_sent_start is None:
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
SENT_START_UNKNOWN,
1
)
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
IS_SENT_START,
0
)
else:
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
SENT_START_UNKNOWN,
0
)
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
IS_SENT_START,
0
)
for i, (head, label) in enumerate(zip(heads, labels)):
if head is not None:
gs.heads[i] = head
gs.labels[i] = label
if i != head:
gs.n_kids[head] += 1
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
HEAD_UNKNOWN,
0
)
else:
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
HEAD_UNKNOWN,
1
)
# Make an array of pointers, pointing into the gs_kids_flat array.
assert gs.length > 0
gs.kids = <int32_t**>mem.alloc(gs.length, sizeof(int32_t*))
for i in range(gs.length):
if gs.n_kids[i] != 0:
gs.kids[i] = <int32_t*>mem.alloc(gs.n_kids[i], sizeof(int32_t))
# This is a temporary buffer
js_addr = Address(gs.length, sizeof(int32_t))
js = <int32_t*>js_addr.ptr
for i in range(gs.length):
if not is_head_unknown(&gs, i):
head = gs.heads[i]
if head != i:
gs.kids[head][js[head]] = i
js[head] += 1
gs.push_cost = push_cost(state, &gs)
gs.pop_cost = pop_cost(state, &gs)
return gs
cdef void update_gold_state(GoldParseStateC* gs, const StateC* s) nogil:
for i in range(gs.length):
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
HEAD_IN_BUFFER,
0
)
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
HEAD_IN_STACK,
0
)
gs.n_kids_in_stack[i] = 0
gs.n_kids_in_buffer[i] = 0
for i in range(s.stack_depth()):
s_i = s.S(i)
if not is_head_unknown(gs, s_i) and gs.heads[s_i] != s_i:
gs.n_kids_in_stack[gs.heads[s_i]] += 1
for kid in gs.kids[s_i][:gs.n_kids[s_i]]:
gs.state_bits[kid] = set_state_flag(
gs.state_bits[kid],
HEAD_IN_STACK,
1
)
for i in range(s.buffer_length()):
b_i = s.B(i)
if s.is_sent_start(b_i):
break
if not is_head_unknown(gs, b_i) and gs.heads[b_i] != b_i:
gs.n_kids_in_buffer[gs.heads[b_i]] += 1
for kid in gs.kids[b_i][:gs.n_kids[b_i]]:
gs.state_bits[kid] = set_state_flag(
gs.state_bits[kid],
HEAD_IN_BUFFER,
1
)
gs.push_cost = push_cost(s, gs)
gs.pop_cost = pop_cost(s, gs)
cdef class ArcEagerGold:
cdef GoldParseStateC c
cdef Pool mem
def __init__(self, ArcEager moves, StateClass stcls, Example example):
self.mem = Pool()
heads, labels = example.get_aligned_parse(projectivize=True)
labels = [label if label is not None else MISSING_DEP_ for label in labels]
labels = [example.x.vocab.strings.add(label) for label in labels]
sent_starts = example.get_aligned_sent_starts()
assert len(heads) == len(labels) == len(sent_starts), (len(heads), len(labels), len(sent_starts))
self.c = create_gold_state(self.mem, stcls.c, heads, labels, sent_starts)
def update(self, StateClass stcls):
update_gold_state(&self.c, stcls.c)
cdef int check_state_gold(char state_bits, char flag) nogil:
cdef char one = 1
return 1 if (state_bits & (one << flag)) else 0
cdef int set_state_flag(char state_bits, char flag, int value) nogil:
cdef char one = 1
if value:
return state_bits | (one << flag)
else:
return state_bits & ~(one << flag)
cdef int is_head_in_stack(const GoldParseStateC* gold, int i) nogil:
return check_state_gold(gold.state_bits[i], HEAD_IN_STACK)
cdef int is_head_in_buffer(const GoldParseStateC* gold, int i) nogil:
return check_state_gold(gold.state_bits[i], HEAD_IN_BUFFER)
cdef int is_head_unknown(const GoldParseStateC* gold, int i) nogil:
return check_state_gold(gold.state_bits[i], HEAD_UNKNOWN)
cdef int is_sent_start(const GoldParseStateC* gold, int i) nogil:
return check_state_gold(gold.state_bits[i], IS_SENT_START)
cdef int is_sent_start_unknown(const GoldParseStateC* gold, int i) nogil:
return check_state_gold(gold.state_bits[i], SENT_START_UNKNOWN)
# Helper functions for the arc-eager oracle
cdef weight_t push_cost(const StateC* state, const GoldParseStateC* gold) nogil:
cdef weight_t cost = 0
b0 = state.B(0)
if b0 < 0:
return 9000
if is_head_in_stack(gold, b0):
cost += 1
cost += gold.n_kids_in_stack[b0]
if Break.is_valid(state, 0) and is_sent_start(gold, state.B(1)):
cost += 1
return cost
cdef weight_t pop_cost(const StateC* state, const GoldParseStateC* gold) nogil:
cdef weight_t cost = 0
s0 = state.S(0)
if s0 < 0:
return 9000
if is_head_in_buffer(gold, s0):
cost += 1
cost += gold.n_kids_in_buffer[s0]
return cost
cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil:
if is_head_unknown(gold, child):
return True
elif gold.heads[child] == head:
return True
else:
return False
cdef bint label_is_gold(const GoldParseStateC* gold, int child, attr_t label) nogil:
if is_head_unknown(gold, child):
return True
elif label == 0:
return True
elif gold.labels[child] == label:
return True
else:
return False
cdef bint _is_gold_root(const GoldParseStateC* gold, int word) nogil:
return gold.heads[word] == word or is_head_unknown(gold, word)
cdef class Shift:
"""Move the first word of the buffer onto the stack and mark it as "shifted"
Validity:
* If stack is empty
* At least two words in sentence
* Word has not been shifted before
Cost: push_cost
Action:
* Mark B[0] as 'shifted'
* Push stack
* Advance buffer
"""
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
if st.stack_depth() == 0:
return 1
elif st.buffer_length() < 2:
return 0
elif st.is_sent_start(st.B(0)):
return 0
elif st.is_unshiftable(st.B(0)):
return 0
else:
return 1
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
st.push()
@staticmethod
cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
return gold.push_cost
cdef class Reduce:
"""
Pop from the stack. If it has no head and the stack isn't empty, place
it back on the buffer.
Validity:
* Stack not empty
* Buffer nt empty
* Stack depth 1 and cannot sent start l_edge(st.B(0))
Cost:
* If B[0] is the start of a sentence, cost is 0
* Arcs between stack and buffer
* If arc has no head, we're saving arcs between S[0] and S[1:], so decrement
cost by those arcs.
"""
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
if st.stack_depth() == 0:
return False
elif st.buffer_length() == 0:
return True
elif st.stack_depth() == 1 and st.cannot_sent_start(st.l_edge(st.B(0))):
return False
else:
return True
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
if st.has_head(st.S(0)) or st.stack_depth() == 1:
st.pop()
else:
st.unshift()
@staticmethod
cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
if state.is_sent_start(state.B(0)):
return 0
s0 = state.S(0)
cost = gold.pop_cost
if not state.has_head(s0):
# Decrement cost for the arcs we save, as we'll be putting this
# back to the buffer
if is_head_in_stack(gold, s0):
cost -= 1
cost -= gold.n_kids_in_stack[s0]
return cost
cdef class LeftArc:
"""Add an arc between B[0] and S[0], replacing the previous head of S[0] if
one is set. Pop S[0] from the stack.
Validity:
* len(S) >= 1
* len(B) >= 1
* not is_sent_start(B[0])
Cost:
pop_cost - Arc(B[0], S[0], label) + (Arc(S[1], S[0]) if H(S[0]) else Arcs(S, S[0]))
"""
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
if st.stack_depth() == 0:
return 0
elif st.buffer_length() == 0:
return 0
elif st.is_sent_start(st.B(0)):
return 0
elif label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
return 0
else:
return 1
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
st.add_arc(st.B(0), st.S(0), label)
# If we change the stack, it's okay to remove the shifted mark, as
# we can't get in an infinite loop this way.
st.set_reshiftable(st.B(0))
st.pop()
@staticmethod
cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
cdef weight_t cost = gold.pop_cost
s0 = state.S(0)
s1 = state.S(1)
b0 = state.B(0)
if state.has_head(s0):
# Increment cost if we're clobbering a correct arc
cost += gold.heads[s0] == s1
else:
# If there's no head, we're losing arcs between S0 and S[1:].
cost += is_head_in_stack(gold, s0)
cost += gold.n_kids_in_stack[s0]
if b0 != -1 and s0 != -1 and gold.heads[s0] == b0:
cost -= 1
cost += not label_is_gold(gold, s0, label)
return cost
cdef class RightArc:
"""
Add an arc from S[0] to B[0]. Push B[0].
Validity:
* len(S) >= 1
* len(B) >= 1
* not is_sent_start(B[0])
Cost:
push_cost + (not shifted[b0] and Arc(B[1:], B[0])) - Arc(S[0], B[0], label)
"""
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
if st.stack_depth() == 0:
return 0
elif st.buffer_length() == 0:
return 0
elif st.is_sent_start(st.B(0)):
return 0
elif label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
# If there's (perhaps partial) parse pre-set, don't allow cycle.
return 0
else:
return 1
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
st.add_arc(st.S(0), st.B(0), label)
st.push()
@staticmethod
cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
cost = gold.push_cost
s0 = state.S(0)
b0 = state.B(0)
if s0 != -1 and b0 != -1 and gold.heads[b0] == s0:
cost -= 1
cost += not label_is_gold(gold, b0, label)
elif is_head_in_buffer(gold, b0) and not state.is_unshiftable(b0):
cost += 1
return cost
cdef class Break:
"""Mark the second word of the buffer as the start of a
sentence.
Validity:
* len(buffer) >= 2
* B[1] == B[0] + 1
* not is_sent_start(B[1])
* not cannot_sent_start(B[1])
Action:
* mark_sent_start(B[1])
Cost:
* not is_sent_start(B[1])
* Arcs between B[0] and B[1:]
* Arcs between S and B[1]
"""
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
cdef int i
if st.buffer_length() < 2:
return False
elif st.B(1) != st.B(0) + 1:
return False
elif st.is_sent_start(st.B(1)):
return False
elif st.cannot_sent_start(st.B(1)):
return False
else:
return True
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
st.set_sent_start(st.B(1), 1)
@staticmethod
cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
cdef int b0 = state.B(0)
cdef int cost = 0
cdef int si
for i in range(state.stack_depth()):
si = state.S(i)
if is_head_in_buffer(gold, si):
cost += 1
cost += gold.n_kids_in_buffer[si]
# We need to score into B[1:], so subtract deps that are at b0
if gold.heads[b0] == si:
cost -= 1
if gold.heads[si] == b0:
cost -= 1
if not is_sent_start(gold, state.B(1)) \
and not is_sent_start_unknown(gold, state.B(1)):
cost += 1
return cost
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
st = new StateC(<const TokenC*>tokens, length)
return <void*>st
cdef int _del_state(Pool mem, void* state, void* x) except -1:
cdef StateC* st = <StateC*>state
del st
cdef class ArcEager(TransitionSystem):
def __init__(self, *args, **kwargs):
TransitionSystem.__init__(self, *args, **kwargs)
self.init_beam_state = _init_state
self.del_beam_state = _del_state
@classmethod
def get_actions(cls, **kwargs):
min_freq = kwargs.get('min_freq', None)
actions = defaultdict(lambda: Counter())
actions[SHIFT][''] = 1
actions[REDUCE][''] = 1
for label in kwargs.get('left_labels', []):
actions[LEFT][label] = 1
actions[SHIFT][label] = 1
for label in kwargs.get('right_labels', []):
actions[RIGHT][label] = 1
actions[REDUCE][label] = 1
for example in kwargs.get('examples', []):
heads, labels = example.get_aligned_parse(projectivize=True)
for child, (head, label) in enumerate(zip(heads, labels)):
if head is None or label is None:
continue
if label.upper() == 'ROOT' :
label = 'ROOT'
if head == child:
actions[BREAK][label] += 1
if head < child:
actions[RIGHT][label] += 1
actions[REDUCE][''] += 1
elif head > child:
actions[LEFT][label] += 1
actions[SHIFT][''] += 1
if min_freq is not None:
for action, label_freqs in actions.items():
for label, freq in list(label_freqs.items()):
if freq < min_freq:
label_freqs.pop(label)
# Ensure these actions are present
actions[BREAK].setdefault('ROOT', 0)
if kwargs.get("learn_tokens") is True:
actions[RIGHT].setdefault('subtok', 0)
actions[LEFT].setdefault('subtok', 0)
# Used for backoff
actions[RIGHT].setdefault('dep', 0)
actions[LEFT].setdefault('dep', 0)
return actions
@property
def action_types(self):
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
def transition(self, StateClass state, action):
cdef Transition t = self.lookup_transition(action)
t.do(state.c, t.label)
return state
def is_gold_parse(self, StateClass state, ArcEagerGold gold):
for i in range(state.c.length):
token = state.c.safe_get(i)
if not arc_is_gold(&gold.c, i, i+token.head):
return False
elif not label_is_gold(&gold.c, i, token.dep):
return False
return True
def init_gold(self, StateClass state, Example example):
gold = ArcEagerGold(self, state, example)
self._replace_unseen_labels(gold)
return gold
def init_gold_batch(self, examples):
# TODO: Projectivity?
all_states = self.init_batch([eg.predicted for eg in examples])
golds = []
states = []
for state, eg in zip(all_states, examples):
if self.has_gold(eg) and not state.is_final():
golds.append(self.init_gold(state, eg))
states.append(state)
n_steps = sum([len(s.queue) for s in states])
return states, golds, n_steps
def _replace_unseen_labels(self, ArcEagerGold gold):
backoff_label = self.strings["dep"]
root_label = self.strings["ROOT"]
left_labels = self.labels[LEFT]
right_labels = self.labels[RIGHT]
break_labels = self.labels[BREAK]
for i in range(gold.c.length):
if not is_head_unknown(&gold.c, i):
head = gold.c.heads[i]
label = self.strings[gold.c.labels[i]]
if head > i and label not in left_labels:
gold.c.labels[i] = backoff_label
elif head < i and label not in right_labels:
gold.c.labels[i] = backoff_label
elif head == i and label not in break_labels:
gold.c.labels[i] = root_label
return gold
cdef Transition lookup_transition(self, object name_or_id) except *:
if isinstance(name_or_id, int):
return self.c[name_or_id]
name = name_or_id
if '-' in name:
move_str, label_str = name.split('-', 1)
label = self.strings[label_str]
else:
move_str = name
label = 0
move = MOVE_NAMES.index(move_str)
for i in range(self.n_moves):
if self.c[i].move == move and self.c[i].label == label:
return self.c[i]
raise KeyError(f"Unknown transition: {name}")
def move_name(self, int move, attr_t label):
label_str = self.strings[label]
if label_str:
return MOVE_NAMES[move] + '-' + label_str
else:
return MOVE_NAMES[move]
def class_name(self, int i):
return self.move_name(self.c[i].move, self.c[i].label)
cdef Transition init_transition(self, int clas, int move, attr_t label) except *:
# TODO: Apparent Cython bug here when we try to use the Transition()
# constructor with the function pointers
cdef Transition t
t.score = 0
t.clas = clas
t.move = move
t.label = label
if move == SHIFT:
t.is_valid = Shift.is_valid
t.do = Shift.transition
t.get_cost = Shift.cost
elif move == REDUCE:
t.is_valid = Reduce.is_valid
t.do = Reduce.transition
t.get_cost = Reduce.cost
elif move == LEFT:
t.is_valid = LeftArc.is_valid
t.do = LeftArc.transition
t.get_cost = LeftArc.cost
elif move == RIGHT:
t.is_valid = RightArc.is_valid
t.do = RightArc.transition
t.get_cost = RightArc.cost
elif move == BREAK:
t.is_valid = Break.is_valid
t.do = Break.transition
t.get_cost = Break.cost
else:
raise ValueError(Errors.E019.format(action=move, src='arc_eager'))
return t
def set_annotations(self, StateClass state, Doc doc):
for arc in state.arcs:
doc.c[arc["child"]].head = arc["head"] - arc["child"]
doc.c[arc["child"]].dep = arc["label"]
for i in range(doc.length):
if doc.c[i].head == 0:
doc.c[i].dep = self.root_label
set_children_from_heads(doc.c, 0, doc.length)
def get_beam_parses(self, Beam beam):
parses = []
probs = beam.probs
for i in range(beam.size):
state = <StateC*>beam.at(i)
if state.is_final():
prob = probs[i]
parse = []
arcs = self.get_arcs(state)
if arcs:
for arc in arcs:
dep = arc["label"]
label = self.strings[dep]
parse.append((arc["head"], arc["child"], label))
parses.append((prob, parse))
return parses
cdef get_arcs(self, StateC* state):
cdef vector[ArcC] arcs
state.get_arcs(&arcs)
return list(arcs)
def has_gold(self, Example eg, start=0, end=None):
for word in eg.y[start:end]:
if word.dep != 0:
return True
else:
return False
cdef int set_valid(self, int* output, const StateC* st) nogil:
cdef int[N_MOVES] is_valid
is_valid[SHIFT] = Shift.is_valid(st, 0)
is_valid[REDUCE] = Reduce.is_valid(st, 0)
is_valid[LEFT] = LeftArc.is_valid(st, 0)
is_valid[RIGHT] = RightArc.is_valid(st, 0)
is_valid[BREAK] = Break.is_valid(st, 0)
cdef int i
for i in range(self.n_moves):
if self.c[i].label == SUBTOK_LABEL:
output[i] = self.c[i].is_valid(st, self.c[i].label)
else:
output[i] = is_valid[self.c[i].move]
def get_cost(self, StateClass stcls, gold, int i):
if not isinstance(gold, ArcEagerGold):
raise TypeError(Errors.E909.format(name="ArcEagerGold"))
cdef ArcEagerGold gold_ = gold
gold_state = gold_.c
n_gold = 0
if self.c[i].is_valid(stcls.c, self.c[i].label):
cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
else:
cost = 9000
return cost
cdef int set_costs(self, int* is_valid, weight_t* costs,
const StateC* state, gold) except -1:
if not isinstance(gold, ArcEagerGold):
raise TypeError(Errors.E909.format(name="ArcEagerGold"))
cdef ArcEagerGold gold_ = gold
gold_state = gold_.c
update_gold_state(&gold_state, state)
self.set_valid(is_valid, state)
cdef int n_gold = 0
for i in range(self.n_moves):
if is_valid[i]:
costs[i] = self.c[i].get_cost(state, &gold_state, self.c[i].label)
if costs[i] <= 0:
n_gold += 1
else:
costs[i] = 9000
if n_gold < 1:
for i in range(self.n_moves):
print(self.get_class_name(i), is_valid[i], costs[i])
print("Gold sent starts?", is_sent_start(&gold_state, state.B(0)), is_sent_start(&gold_state, state.B(1)))
raise ValueError("Could not find gold transition - see logs above.")
def get_oracle_sequence_from_state(self, StateClass state, ArcEagerGold gold, _debug=None):
cdef int i
cdef Pool mem = Pool()
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
assert self.n_moves > 0
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
history = []
debug_log = []
failed = False
while not state.is_final():
try:
self.set_costs(is_valid, costs, state.c, gold)
except ValueError:
failed = True
break
min_cost = min(costs[i] for i in range(self.n_moves))
for i in range(self.n_moves):
if is_valid[i] and costs[i] <= min_cost:
action = self.c[i]
history.append(i)
s0 = state.S(0)
b0 = state.B(0)
if _debug:
example = _debug
debug_log.append(" ".join((
self.get_class_name(i),
state.print_state()
)))
action.do(state.c, action.label)
break
else:
failed = False
break
if failed:
example = _debug
print("Actions")
for i in range(self.n_moves):
print(self.get_class_name(i))
print("Gold")
for token in example.y:
print(token.i, token.text, token.dep_, token.head.text)
aligned_heads, aligned_labels = example.get_aligned_parse()
print("Aligned heads")
for i, head in enumerate(aligned_heads):
print(example.x[i], example.x[head] if head is not None else "__")
print("Aligned sent starts")
print(example.get_aligned_sent_starts())
print("Predicted tokens")
print([(w.i, w.text) for w in example.x])
s0 = state.S(0)
b0 = state.B(0)
debug_log.append(" ".join((
"?",
"S0=", (example.x[s0].text if s0 >= 0 else "-"),
"B0=", (example.x[b0].text if b0 >= 0 else "-"),
"S0 head?", str(state.has_head(state.S(0))),
)))
s0 = state.S(0)
b0 = state.B(0)
print("\n".join(debug_log))
print("Arc is gold B0, S0?", arc_is_gold(&gold.c, b0, s0))
print("Arc is gold S0, B0?", arc_is_gold(&gold.c, s0, b0))
print("is_head_unknown(s0)", is_head_unknown(&gold.c, s0))
print("is_head_unknown(b0)", is_head_unknown(&gold.c, b0))
print("b0", b0, "gold.heads[s0]", gold.c.heads[s0])
print("Stack", [example.x[i] for i in state.stack])
print("Buffer", [example.x[i] for i in state.queue])
raise ValueError(Errors.E024)
return history