* Refactored transition system code now compiling. Still need to hook up label oracle, and test

This commit is contained in:
Matthew Honnibal 2015-02-22 00:32:07 -05:00
parent 6e86790a4e
commit 8c883cef58
7 changed files with 109 additions and 63 deletions

View File

@ -112,18 +112,17 @@ cdef int count_right_kids(const TokenC* head) nogil:
return _popcount(head.r_kids)
cdef State* init_state(Pool mem, TokenC* sent, const int sent_length) except NULL:
cdef int padded_len = sent_length + PADDING + PADDING
cdef State* init_state(Pool mem, const TokenC* sent, const int sent_len) except NULL:
cdef int padded_len = sent_len + PADDING + PADDING
cdef State* s = <State*>mem.alloc(1, sizeof(State))
s.stack = <int*>mem.alloc(padded_len, sizeof(int))
for i in range(PADDING):
s.stack[i] = -1
s.stack += (PADDING - 1)
assert s.stack[0] == -1
s.sent = sent
s.sent = <TokenC*>mem.alloc(sent_len, sizeof(TokenC))
s.stack_len = 0
s.i = 0
s.sent_len = sent_length
s.sent_len = sent_len
push_stack(s)
return s

View File

@ -33,14 +33,28 @@ cdef get_cost_func_t[N_MOVES] get_cost_funcs
cdef class ArcEager(TransitionSystem):
@classmethod
def get_labels(cls, gold_parses):
labels = {RIGHT: {}, LEFT: {}}
for parse in gold_parses:
for i, (head, label) in enumerate(zip(parse.heads, parse.labels)):
if head > i:
labels[RIGHT][label] = True
else:
labels[LEFT][label] = True
return labels
cdef Transition init_transition(self, int clas, int move, int label) except *:
return Transition(
score=0,
clas=i,
move=move,
label=label,
do=do_funcs[move],
get_cost=get_cost_funcs[move])
# TODO: Apparent Cython bug here when we try to use the Transition()
# constructor with the function pointers
cdef Transition t
t.score = 0
t.clas = clas
t.move = move
t.label = label
t.do = do_funcs[move]
t.get_cost = get_cost_funcs[move]
return t
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
cdef bint[N_MOVES] is_valid
@ -111,8 +125,8 @@ do_funcs[BREAK] = _do_break
cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1:
assert not at_eol(s)
cost = 0
cost += head_in_stack(s, s.i, gold.heads)
cost += children_in_stack(s, s.i, gold.heads)
cost += head_in_stack(s, s.i, gold.c_heads)
cost += children_in_stack(s, s.i, gold.c_heads)
if NON_MONOTONIC:
cost += gold[s.stack[0]] == s.i
# If we can break, and there's no cost to doing so, we should
@ -126,9 +140,9 @@ cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) exc
cost = 0
if gold[s.i] == s.stack[0]:
return cost
cost += head_in_buffer(s, s.i, gold.heads)
cost += children_in_stack(s, s.i, gold.heads)
cost += head_in_stack(s, s.i, gold.heads)
cost += head_in_buffer(s, s.i, gold.c_heads)
cost += children_in_stack(s, s.i, gold.c_heads)
cost += head_in_stack(s, s.i, gold.c_heads)
if NON_MONOTONIC:
cost += gold[s.stack[0]] == s.i
return cost
@ -140,8 +154,8 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) exce
if gold[s.stack[0]] == s.i:
return cost
cost += head_in_buffer(s, s.stack[0], gold.heads)
cost += children_in_buffer(s, s.stack[0], gold.heads)
cost += head_in_buffer(s, s.stack[0], gold.c_heads)
cost += children_in_buffer(s, s.stack[0], gold.c_heads)
if NON_MONOTONIC and s.stack_len >= 2:
cost += gold[s.stack[0]] == s.stack[-1]
cost += gold[s.stack[0]] == s.stack[0]
@ -150,9 +164,9 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) exce
cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) except -1:
cdef int cost = 0
cost += children_in_buffer(s, s.stack[0], gold.heads)
cost += children_in_buffer(s, s.stack[0], gold.c_heads)
if NON_MONOTONIC:
cost += head_in_buffer(s, s.stack[0], gold.heads)
cost += head_in_buffer(s, s.stack[0], gold.c_heads)
return cost
@ -161,8 +175,8 @@ cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) exc
cdef int cost = 0
# Number of deps between S0...Sn and N0...Nn
for i in range(s.i, s.sent_len):
cost += children_in_stack(s, i, gold.heads)
cost += head_in_stack(s, i, gold.heads)
cost += children_in_stack(s, i, gold.c_heads)
cost += head_in_stack(s, i, gold.c_heads)
return cost

View File

@ -1,7 +1,22 @@
from cymem.cymem cimport Pool
from ..structs cimport TokenC
cdef class GoldParse:
cdef int* heads
cdef int* labels
cdef Pool mem
cdef int* c_heads
cdef int* c_labels
cdef int length
cdef int loss
cdef unicode raw_text
cdef list words
cdef list ids
cdef list tags
cdef list heads
cdef list labels
cdef int heads_correct(self, TokenC* tokens, bint score_punct=?) except -1

View File

@ -1,11 +1,20 @@
cdef class GoldParse:
def __init__(self):
pass
def __init__(self, raw_text, words, ids, tags, heads, labels):
self.mem = Pool()
self.loss = 0
self.length = len(words)
self.raw_text = raw_text
self.words = words
self.ids = ids
self.tags = tags
self.heads = heads
self.labels = labels
self.c_heads = <int*>self.mem.alloc(self.length, sizeof(int))
self.c_labels = <int*>self.mem.alloc(self.length, sizeof(int))
cdef int heads_correct(self, TokenC* tokens, bint score_punct=False) except -1:
pass
"""
@classmethod
def from_conll(cls, unicode sent_str):
ids = []
@ -50,42 +59,44 @@ cdef class GoldParse:
for sent_str in tok_text.split('<SENT>')]
return cls(raw_text, tokenized, ids, words, tags, heads, labels)
cdef int heads_correct(self, TokenC* tokens, bint score_punct=False) except -1:
pass
def align_to_non_gold_tokens(self, tokens):
# TODO
tags = []
heads = []
labels = []
orig_words = list(words)
def align_to_tokens(self, tokens, label_ids):
orig_words = list(self.words)
annot = zip(self.ids, self.tags, self.heads, self.labels)
self.ids = []
self.tags = []
self.heads = []
self.labels = []
missed = []
for token in tokens:
while annot and token.idx > annot[0][0]:
miss_id, miss_tag, miss_head, miss_label = annot.pop(0)
miss_w = words.pop(0)
miss_w = self.words.pop(0)
if not is_punct_label(miss_label):
missed.append(miss_w)
loss += 1
self.loss += 1
if not annot:
tags.append(None)
heads.append(None)
labels.append(None)
self.tags.append(None)
self.heads.append(None)
self.labels.append(None)
continue
id_, tag, head, label = annot[0]
if token.idx == id_:
tags.append(tag)
heads.append(head)
labels.append(label)
self.tags.append(tag)
self.heads.append(head)
self.labels.append(label)
annot.pop(0)
words.pop(0)
self.words.pop(0)
elif token.idx < id_:
tags.append(None)
heads.append(None)
labels.append(None)
self.tags.append(None)
self.heads.append(None)
self.labels.append(None)
else:
raise StandardError
return loss, tags, heads, labels
mapped_heads = _map_indices_to_tokens(self.ids, self.heads)
for i in range(self.length):
self.c_heads[i] = mapped_heads[i]
self.c_labels[i] = label_ids[self.labels[i]]
return self.loss
def is_punct_label(label):
@ -116,6 +127,7 @@ def _parse_line(line):
return id_, word, pos, head_idx, label
"""
# TODO
def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
global loss

View File

@ -26,9 +26,10 @@ from thinc.learner cimport LinearModel
from ..tokens cimport Tokens, TokenC
from .arc_eager cimport TransitionSystem, Transition
from .arc_eager import OracleError
from .transition_system import OracleError
from ._state cimport init_state, State, is_final, get_idx, get_s0, get_s1, get_n0, get_n1
from .conll cimport GoldParse
from . import _parse_features
from ._parse_features cimport fill_context, CONTEXT_SIZE
@ -60,10 +61,10 @@ def get_templates(name):
cdef class GreedyParser:
def __init__(self, model_dir):
def __init__(self, model_dir, transition_system):
assert os.path.exists(model_dir) and os.path.isdir(model_dir)
self.cfg = Config.read(model_dir, 'config')
self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels)
self.moves = transition_system(self.cfg.labels)
templates = get_templates(self.cfg.features)
self.model = Model(self.moves.n_moves, templates, model_dir)
@ -74,23 +75,24 @@ cdef class GreedyParser:
cdef atom_t[CONTEXT_SIZE] context
cdef int n_feats
cdef Pool mem = Pool()
cdef State* state = init_state(mem, tokens.data, tokens.length) # TODO
cdef State* state = init_state(mem, tokens.data, tokens.length)
cdef Transition guess
while not is_final(state):
fill_context(context, state)
scores = self.model.score(context)
guess = self.moves.best_valid(scores, state)
guess.do(&guess, state)
tokens.set_parse(state.sent, self.moves.label_ids) # TODO
tokens.set_parse(state.sent, self.moves.label_ids)
return 0
def train_sent(self, Tokens tokens, GoldParse gold, force_gold=False):
def train(self, Tokens tokens, GoldParse gold, force_gold=False):
cdef:
int n_feats
int cost
const Feature* feats
const weight_t* scores
Transition guess
Transition gold
Transition best
atom_t[CONTEXT_SIZE] context
@ -101,7 +103,7 @@ cdef class GreedyParser:
fill_context(context, state)
scores = self.model.score(context)
guess = self.moves.best_valid(scores, state)
gold = self.moves.best_gold(scores, state, gold)
best = self.moves.best_gold(scores, state, gold)
cost = guess.get_cost(&guess, state, gold)
self.model.update(context, guess.clas, best.clas, cost)
if force_gold:

View File

@ -30,7 +30,7 @@ cdef class TransitionSystem:
cdef Transition init_transition(self, int clas, int move, int label) except *
cdef const Transition best_valid(self, const weight_t*, const State*) except *
cdef Transition best_valid(self, const weight_t* scores, const State* state) except *
cdef const Transition best_gold(self, const weight_t*, const State*,
GoldParse gold) except *
cdef Transition best_gold(self, const weight_t* scores, const State* state,
GoldParse gold) except *

View File

@ -7,6 +7,10 @@ from thinc.typedefs cimport weight_t
cdef weight_t MIN_SCORE = -90000
class OracleError(Exception):
pass
cdef class TransitionSystem:
def __init__(self, dict labels_by_action):
self.mem = Pool()
@ -28,7 +32,7 @@ cdef class TransitionSystem:
raise NotImplementedError
cdef Transition best_gold(self, const weight_t* scores, const State* s,
const TokenC* gold) except *:
GoldParse gold) except *:
cdef Transition best
cdef weight_t score = MIN_SCORE
cdef int i