* Work on refactored parser, where TransitionSystem can be easily subclassed

This commit is contained in:
Matthew Honnibal 2015-02-20 23:30:31 -05:00
parent 1cc6329b18
commit dc986dbc0b

View File

@ -53,10 +53,10 @@ cdef unicode print_state(State* s, list words):
def get_templates(name): def get_templates(name):
pf = _parse_features pf = _parse_features
if name == 'zhang': if name == 'zhang':
return pf.unigrams, pf.arc_eager return pf.arc_eager
else: else:
return pf.unigrams, (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s0_n1 + pf.n0_n1 + \ return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s0_n1 + pf.n0_n1 + \
pf.tree_shape + pf.trigrams) pf.tree_shape + pf.trigrams)
cdef class GreedyParser: cdef class GreedyParser:
@ -64,84 +64,51 @@ cdef class GreedyParser:
assert os.path.exists(model_dir) and os.path.isdir(model_dir) assert os.path.exists(model_dir) and os.path.isdir(model_dir)
self.cfg = Config.read(model_dir, 'config') self.cfg = Config.read(model_dir, 'config')
self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels) self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels)
hasty_templ, full_templ = get_templates(self.cfg.features) templates = get_templates(self.cfg.features)
self.model = Model(self.moves.n_moves, full_templ, model_dir) self.model = Model(self.moves.n_moves, templates, model_dir)
def __call__(self, Tokens tokens): def __call__(self, Tokens tokens):
cdef:
Transition guess
uint64_t state_key
if tokens.length == 0: if tokens.length == 0:
return 0 return 0
cdef atom_t[CONTEXT_SIZE] context cdef atom_t[CONTEXT_SIZE] context
cdef int n_feats cdef int n_feats
cdef Pool mem = Pool() cdef Pool mem = Pool()
cdef State* state = init_state(mem, tokens.data, tokens.length) cdef State* state = init_state(mem, tokens.data, tokens.length) # TODO
cdef Transition guess
while not is_final(state): while not is_final(state):
fill_context(context, state) fill_context(context, state)
scores = self.model.score(context) scores = self.model.score(context)
guess = self.moves.best_valid(scores, state) guess = self.moves.best_valid(scores, state)
guess.do(&guess, state) guess.do(&guess, state)
# Messily tell Tokens object the string names of the dependency labels tokens.set_parse(state.sent, self.moves.label_ids) # TODO
dep_strings = [None] * len(self.moves.label_ids)
for label, id_ in self.moves.label_ids.items():
dep_strings[id_] = label
tokens._dep_strings = tuple(dep_strings)
tokens.is_parsed = True
# TODO: Clean this up.
tokens._py_tokens = [None] * tokens.length
return 0 return 0
def train_sent(self, Tokens tokens, list gold_heads, list gold_labels, def train_sent(self, Tokens tokens, GoldParse gold, force_gold=False):
force_gold=False):
cdef: cdef:
int n_feats
const Feature* feats const Feature* feats
const weight_t* scores const weight_t* scores
Transition guess Transition guess
Transition gold Transition gold
atom_t[CONTEXT_SIZE] context
cdef int n_feats
cdef atom_t[CONTEXT_SIZE] context
cdef Pool mem = Pool() cdef Pool mem = Pool()
cdef int* heads_array = <int*>mem.alloc(tokens.length, sizeof(int))
cdef int* labels_array = <int*>mem.alloc(tokens.length, sizeof(int))
cdef int i
for i in range(tokens.length):
if gold_heads[i] is None:
heads_array[i] = -1
labels_array[i] = -1
else:
heads_array[i] = gold_heads[i]
labels_array[i] = self.moves.label_ids[gold_labels[i]]
py_words = [t.orth_ for t in tokens]
py_moves = ['S', 'D', 'L', 'R', 'BS', 'BR']
history = []
#print py_words
cdef State* state = init_state(mem, tokens.data, tokens.length) cdef State* state = init_state(mem, tokens.data, tokens.length)
while not is_final(state): while not is_final(state):
fill_context(context, state) fill_context(context, state)
scores = self.model.score(context) scores = self.model.score(context)
guess = self.moves.best_valid(scores, state) guess = self.moves.best_valid(scores, state)
best = self.moves.best_gold(&guess, scores, state, heads_array, labels_array) gold = self.moves.best_gold(scores, state, gold)
history.append((py_moves[best.move], print_state(state, py_words))) cost = guess.get_cost(&guess, state, gold)
self.model.update(context, guess.clas, best.clas, guess.cost) self.model.update(context, guess.clas, best.clas, cost)
if force_gold: if force_gold:
best.do(&best, state) best.do(&best, state)
else: else:
guess.do(&guess, state) guess.do(&guess, state)
cdef int n_corr = 0 n_corr = gold.heads_correct(state.sent, score_punct=True) # TODO
for i in range(tokens.length):
if gold_heads[i] != -1:
n_corr += (i + state.sent[i].head) == gold_heads[i]
if force_gold and n_corr != tokens.length: if force_gold and n_corr != tokens.length:
#print py_words
#print gold_heads
#for move, state_str in history:
# print move, state_str
#for i in range(tokens.length):
# print py_words[i], py_words[i + state.sent[i].head], py_words[gold_heads[i]]
raise OracleError raise OracleError
return n_corr return n_corr