WIP on updating transition-system

This commit is contained in:
Matthew Honnibal 2020-06-14 17:22:14 +02:00
parent 7d65615625
commit 60d4e5a9e0
5 changed files with 132 additions and 198 deletions

View File

@ -16,7 +16,6 @@ from thinc.api import Linear, Model, CupyOps, NumpyOps, use_ops, noop
from ..typedefs cimport weight_t, class_t, hash_t
from ..tokens.doc cimport Doc
from ..gold cimport GoldParse
from .stateclass cimport StateClass
from .transition_system cimport Transition

View File

@ -8,7 +8,6 @@ import json
from ..typedefs cimport hash_t, attr_t
from ..strings cimport hash_string
from ..gold cimport GoldParse, GoldParseC
from ..structs cimport TokenC
from ..tokens.doc cimport Doc, set_children_from_heads
from .stateclass cimport StateClass
@ -49,40 +48,75 @@ MOVE_NAMES[RIGHT] = 'R'
MOVE_NAMES[BREAK] = 'B'
cdef enum:
HEAD_ON_STACK = 0
HEAD_IN_BUFFER
IS_SENT_START
HEAD_UNKNOWN
cdef struct GoldParseStateC:
char* state_bits
attr_t* labels
int32_t* heads
int32_t* n_kids_in_buffer
int32_t* n_kids_on_stack
int32_t length
int32_t stride
cdef int check_state_flag(char state_bits, char flag) nogil:
cdef char one = 1
return state_bits & (one << flag)
cdef int set_state_flag(char state_bits, char flag, int value) nogil:
cdef char one = 1
if value:
return state_bits | (one << flag)
else:
return state_bits & ~(one << flag)
cdef int is_head_on_stack(GoldParseStateC gold, int i) nogil:
return check_state_gold(gold.state_bits[i], HEAD_ON_STACK)
cdef int is_head_in_buffer(GoldParseStateC gold, int i) nogil:
return check_state_gold(gold.state_bits[i], HEAD_IN_BUFFER)
cdef int is_sent_start(GoldParseStateC gold, int i) nogil:
return check_state_gold(gold.state_bits[i], IS_SENT_START)
cdef int is_head_unknown(GoldParseStateC gold, int i) nogil:
return check_state_gold(gold.state_bits[i], HEAD_UNKNOWN)
# Helper functions for the arc-eager oracle
cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
cdef weight_t push_cost(StateClass stcls, const GoldParseStateC* gold, int target) nogil:
cdef weight_t cost = 0
cdef int i, S_i
for i in range(stcls.stack_depth()):
S_i = stcls.S(i)
if gold.heads[target] == S_i:
if is_head_in_stack(gold[0], target):
cost += 1
if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
cost += 1
if BINARY_COSTS and cost >= 1:
return cost
cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0
return cost
cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
cdef weight_t cost = 0
cdef int i, B_i
for i in range(stcls.buffer_length()):
B_i = stcls.B(i)
cost += gold.heads[B_i] == target
cost += gold.heads[target] == B_i
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
break
if BINARY_COSTS and cost >= 1:
return cost
cost += gold.n_kids_in_buffer[target]
if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
cost += 1
return cost
cdef weight_t arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil:
cdef weight_t pop_cost(StateClass stcls, const GoldParseStateC* gold, int target) nogil:
cdef weight_t cost = 0
if is_head_in_buffer(gold[0], target):
cost += 1
cost += gold[0].n_kids_in_buffer[target]
if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
cost += 1
return cost
cdef weight_t arc_cost(StateClass stcls, const GoldParseStateC* gold, int head, int child) nogil:
if arc_is_gold(gold, head, child):
return 0
elif stcls.H(child) == gold.heads[child]:
@ -94,8 +128,8 @@ cdef weight_t arc_cost(StateClass stcls, const GoldParseC* gold, int head, int c
return 0
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil:
if not gold.has_dep[child]:
cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil:
if is_head_unknown(gold[0], child):
return True
elif gold.heads[child] == head:
return True
@ -103,8 +137,8 @@ cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil:
return False
cdef bint label_is_gold(const GoldParseC* gold, int head, int child, attr_t label) nogil:
if not gold.has_dep[child]:
cdef bint label_is_gold(const GoldParseStateC* gold, int head, int child, attr_t label) nogil:
if is_head_unknown(gold[0], child):
return True
elif label == 0:
return True
@ -114,8 +148,9 @@ cdef bint label_is_gold(const GoldParseC* gold, int head, int child, attr_t labe
return False
cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
return gold.heads[word] == word or not gold.has_dep[word]
cdef bint _is_gold_root(const GoldParseStateC* gold, int word) nogil:
return gold.heads[word] == word or is_head_unknown(gold[0], word)
cdef class Shift:
@staticmethod
@ -129,15 +164,16 @@ cdef class Shift:
st.fast_forward()
@staticmethod
cdef weight_t cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
cdef weight_t cost(StateClass st, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label)
@staticmethod
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil:
return push_cost(s, gold, s.B(0))
@staticmethod
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil:
return 0
@ -155,26 +191,27 @@ cdef class Reduce:
st.fast_forward()
@staticmethod
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
@staticmethod
cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
cost = pop_cost(st, gold, st.S(0))
if not st.has_head(st.S(0)):
# Decrement cost for the arcs e save
for i in range(1, st.stack_depth()):
S_i = st.S(i)
if gold.heads[st.S(0)] == S_i:
cost -= 1
if gold.heads[S_i] == st.S(0):
cdef inline weight_t move_cost(StateClass st, const GoldParseStateC* gold) nogil:
s0 = st.S(0)
cost = pop_cost(st, gold, s0)
return_to_buffer = not st.has_head(s0)
if return_to_buffer:
# Decrement cost for the arcs we save, as we'll be putting this
# back to the buffer
if is_head_in_stack(gold[0], s0):
cost -= 1
cost -= gold.n_kids_in_stack[s0]
if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0:
cost -= 1
return cost
@staticmethod
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil:
return 0
@ -193,49 +230,12 @@ cdef class LeftArc:
st.fast_forward()
@staticmethod
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
@staticmethod
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
cdef weight_t cost = 0
if arc_is_gold(gold, s.B(0), s.S(0)):
# Have a negative cost if we 'recover' from the wrong dependency
return 0 if not s.has_head(s.S(0)) else -1
else:
# Account for deps we might lose between S0 and stack
if not s.has_head(s.S(0)):
for i in range(1, s.stack_depth()):
cost += gold.heads[s.S(i)] == s.S(0)
cost += gold.heads[s.S(0)] == s.S(i)
return cost + pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
@staticmethod
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
cdef class RightArc:
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
# If there's (perhaps partial) parse pre-set, don't allow cycle.
if label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
return 0
sent_start = st._sent[st.B_(0).l_edge].sent_start
return sent_start != 1 and st.H(st.S(0)) != st.B(0)
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
st.add_arc(st.S(0), st.B(0), label)
st.push()
st.fast_forward()
@staticmethod
cdef inline weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
cdef inline weight_t cost(StateClass s, const void* gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
@staticmethod
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil:
if arc_is_gold(gold, s.S(0), s.B(0)):
return 0
elif s.c.shifted[s.B(0)]:
@ -244,7 +244,7 @@ cdef class RightArc:
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
@staticmethod
cdef weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
cdef weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil:
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
@ -271,21 +271,19 @@ cdef class Break:
st.fast_forward()
@staticmethod
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
@staticmethod
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil:
cdef weight_t cost = 0
cdef int i, j, S_i, B_i
for i in range(s.stack_depth()):
S_i = s.S(i)
for j in range(s.buffer_length()):
B_i = s.B(j)
cost += gold.heads[S_i] == B_i
cost += gold.heads[B_i] == S_i
if cost != 0:
return cost
cost += gold.n_kids_in_buffer[S_i]
if is_head_in_buffer(gold[0], S_i):
cost += 1
# Check for sentence boundary --- if it's here, we can't have any deps
# between stack and buffer, so rest of action is irrelevant.
s0_root = _get_root(s.S(0), gold)
@ -296,13 +294,15 @@ cdef class Break:
return cost + 1
@staticmethod
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil:
return 0
cdef int _get_root(int word, const GoldParseC* gold) nogil:
while gold.heads[word] != word and gold.has_dep[word] and word >= 0:
cdef int _get_root(int word, const GoldParseStateC* gold) nogil:
if is_head_unset(gold[0], word):
return -1
while gold.heads[word] != word and word >= 0:
word = gold.heads[word]
if not gold.has_dep[word]:
if is_head_unset(gold[0], word):
return -1
else:
return word
@ -378,86 +378,22 @@ cdef class ArcEager(TransitionSystem):
def action_types(self):
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
def get_cost(self, StateClass state, GoldParse gold, action):
cdef Transition t = self.lookup_transition(action)
if not t.is_valid(state.c, t.label):
return 9000
else:
return t.get_cost(state, &gold.c, t.label)
def get_cost(self, StateClass state, NewExample gold, action):
raise NotImplementedError
def transition(self, StateClass state, action):
cdef Transition t = self.lookup_transition(action)
t.do(state.c, t.label)
return state
def is_gold_parse(self, StateClass state, GoldParse gold):
predicted = set()
truth = set()
for i in range(gold.length):
if gold.cand_to_gold[i] is None:
continue
if state.safe_get(i).dep:
predicted.add((i, state.H(i),
self.strings[state.safe_get(i).dep]))
else:
predicted.add((i, state.H(i), 'ROOT'))
id_ = gold.orig.ids[gold.cand_to_gold[i]]
head = gold.orig.heads[gold.cand_to_gold[i]]
dep = gold.orig.deps[gold.cand_to_gold[i]]
truth.add((id_, head, dep))
return truth == predicted
def is_gold_parse(self, StateClass state, gold):
raise NotImplementedError
def has_gold(self, GoldParse gold, start=0, end=None):
end = end or len(gold.heads)
if all([tag is None for tag in gold.heads[start:end]]):
return False
else:
return True
def has_gold(self, gold, start=0, end=None):
raise NotImplementedError
def preprocess_gold(self, GoldParse gold):
if not self.has_gold(gold):
return None
# Figure out whether we're using subtok
use_subtok = False
for action, labels in self.labels.items():
if SUBTOK_LABEL in labels:
use_subtok = True
break
for i, (head, dep) in enumerate(zip(gold.heads, gold.labels)):
# Missing values
if head is None or dep is None:
gold.c.heads[i] = i
gold.c.has_dep[i] = False
elif dep == SUBTOK_LABEL and not use_subtok:
# If we're not doing the joint tokenization and parsing,
# regard these subtok labels as missing
gold.c.heads[i] = i
gold.c.labels[i] = 0
gold.c.has_dep[i] = False
else:
if head > i:
action = LEFT
elif head < i:
action = RIGHT
else:
action = BREAK
if dep not in self.labels[action]:
if action == BREAK:
dep = 'ROOT'
elif nonproj.is_decorated(dep):
backoff = nonproj.decompose(dep)[0]
if backoff in self.labels[action]:
dep = backoff
else:
dep = 'dep'
else:
dep = 'dep'
gold.c.has_dep[i] = True
if dep.upper() == 'ROOT':
dep = 'ROOT'
gold.c.heads[i] = head
gold.c.labels[i] = self.strings.add(dep)
return gold
def preprocess_gold(self, gold):
raise NotImplementedError
def get_beam_parses(self, Beam beam):
parses = []
@ -569,7 +505,9 @@ cdef class ArcEager(TransitionSystem):
output[i] = is_valid[self.c[i].move]
cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass stcls, GoldParse gold) except -1:
StateClass stcls, NewExample example) except -1:
cdef Pool mem = Pool()
gold_state = create_gold_state(mem, stcls, example)
cdef int i, move
cdef attr_t label
cdef label_cost_func_t[N_MOVES] label_cost_funcs
@ -599,8 +537,8 @@ cdef class ArcEager(TransitionSystem):
move = self.c[i].move
label = self.c[i].label
if move_costs[move] == 9000:
move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
move_costs[move] = move_cost_funcs[move](stcls, gold_state)
costs[i] = move_costs[move] + label_cost_funcs[move](stcls, gold_state, label)
n_gold += costs[i] <= 0
else:
is_valid[i] = False

View File

@ -1,6 +1,5 @@
from .transition_system cimport TransitionSystem
from .transition_system cimport Transition
from ..gold cimport GoldParseC
from ..typedefs cimport attr_t

View File

@ -7,11 +7,11 @@ from .stateclass cimport StateClass
from ._state cimport StateC
from .transition_system cimport Transition
from .transition_system cimport do_func_t
from ..gold cimport GoldParseC, GoldParse
from ..lexeme cimport Lexeme
from ..attrs cimport IS_SPACE
from ..errors import Errors
from .gold_parse cimport GoldParseC
cdef enum:
@ -91,19 +91,11 @@ cdef class BiluoPushDown(TransitionSystem):
else:
return MOVE_NAMES[move] + '-' + self.strings[label]
def has_gold(self, GoldParse gold, start=0, end=None):
end = end or len(gold.ner)
if all([tag in ('-', None) for tag in gold.ner[start:end]]):
return False
else:
return True
def has_gold(self, gold, start=0, end=None):
raise NotImplementedError
def preprocess_gold(self, GoldParse gold):
if not self.has_gold(gold):
return None
for i in range(gold.length):
gold.c.ner[i] = self.lookup_transition(gold.ner[i])
return gold
def preprocess_gold(self, gold):
raise NotImplementedError
def get_beam_annot(self, Beam beam):
entities = {}
@ -248,7 +240,7 @@ cdef class Missing:
pass
@staticmethod
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
return 9000
@ -300,7 +292,8 @@ cdef class Begin:
st.pop()
@staticmethod
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <GoldParseC*>_gold
cdef int g_act = gold.ner[s.B(0)].move
cdef attr_t g_tag = gold.ner[s.B(0)].label
@ -363,7 +356,8 @@ cdef class In:
st.pop()
@staticmethod
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <GoldParseC*>_gold
move = IN
cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT
cdef int g_act = gold.ner[s.B(0)].move
@ -429,7 +423,8 @@ cdef class Last:
st.pop()
@staticmethod
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <GoldParseC*>_gold
move = LAST
cdef int g_act = gold.ner[s.B(0)].move
@ -497,7 +492,8 @@ cdef class Unit:
st.pop()
@staticmethod
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <GoldParseC*>_gold
cdef int g_act = gold.ner[s.B(0)].move
cdef attr_t g_tag = gold.ner[s.B(0)].label
@ -537,7 +533,8 @@ cdef class Out:
st.pop()
@staticmethod
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <GoldParseC*>_gold
cdef int g_act = gold.ner[s.B(0)].move
cdef attr_t g_tag = gold.ner[s.B(0)].label

View File

@ -5,6 +5,7 @@ from ..structs cimport TokenC
from ..strings cimport StringStore
from .stateclass cimport StateClass
from ._state cimport StateC
from ..gold.new_example cimport NewExample
cdef struct Transition: