* Tmp commit

This commit is contained in:
Matthew Honnibal 2015-02-23 14:04:53 -05:00
parent 4f83c9b3d5
commit 10ed738df2
8 changed files with 46 additions and 68 deletions

View File

@ -1,4 +1,4 @@
from libc.string cimport memmove
from libc.string cimport memmove, memcpy
from cymem.cymem cimport Pool
from ..lexeme cimport EMPTY_LEXEME
@ -120,7 +120,9 @@ cdef State* init_state(Pool mem, const TokenC* sent, const int sent_len) except
s.stack[i] = -1
s.stack += (PADDING - 1)
assert s.stack[0] == -1
s.sent = <TokenC*>mem.alloc(sent_len, sizeof(TokenC))
state_sent = <TokenC*>mem.alloc(padded_len, sizeof(TokenC))
memcpy(state_sent, sent - PADDING, padded_len * sizeof(TokenC))
s.sent = state_sent + PADDING
s.stack_len = 0
s.i = 0
s.sent_len = sent_len

View File

@ -35,7 +35,8 @@ cdef get_cost_func_t[N_MOVES] get_cost_funcs
cdef class ArcEager(TransitionSystem):
@classmethod
def get_labels(cls, gold_parses):
labels = {RIGHT: {}, LEFT: {}}
labels = {SHIFT: {0: True}, REDUCE: {0: True}, RIGHT: {0: True},
LEFT: {0: True}, BREAK: {0: True}}
for parse in gold_parses:
for i, (head, label) in enumerate(zip(parse.heads, parse.labels)):
if head > i:
@ -128,7 +129,7 @@ cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) exc
cost += head_in_stack(s, s.i, gold.c_heads)
cost += children_in_stack(s, s.i, gold.c_heads)
if NON_MONOTONIC:
cost += gold[s.stack[0]] == s.i
cost += gold.c_heads[s.stack[0]] == s.i
# If we can break, and there's no cost to doing so, we should
if _can_break(s) and _break_cost(self, s, gold) == 0:
cost += 1
@ -138,29 +139,29 @@ cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) exc
cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) except -1:
assert s.stack_len >= 1
cost = 0
if gold[s.i] == s.stack[0]:
if gold.c_heads[s.i] == s.stack[0]:
cost += self.label != gold.c_labels[s.i]
return cost
cost += head_in_buffer(s, s.i, gold.c_heads)
cost += children_in_stack(s, s.i, gold.c_heads)
cost += head_in_stack(s, s.i, gold.c_heads)
if NON_MONOTONIC:
cost += gold[s.stack[0]] == s.i
cost += gold.c_heads[s.stack[0]] == s.i
return cost
cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) except -1:
assert s.stack_len >= 1
cost = 0
if gold[s.stack[0]] == s.i:
cost += self.label != gold.c_labels[s.top]
if gold.c_heads[s.stack[0]] == s.i:
cost += self.label != gold.c_labels[s.stack[0]]
return cost
cost += head_in_buffer(s, s.stack[0], gold.c_heads)
cost += children_in_buffer(s, s.stack[0], gold.c_heads)
if NON_MONOTONIC and s.stack_len >= 2:
cost += gold[s.stack[0]] == s.stack[-1]
cost += gold[s.stack[0]] == s.stack[0]
cost += gold.c_heads[s.stack[0]] == s.stack[-1]
cost += gold.c_heads[s.stack[0]] == s.stack[0]
return cost

View File

@ -11,12 +11,12 @@ cdef class GoldParse:
cdef int length
cdef int loss
cdef unicode raw_text
cdef list words
cdef list ids
cdef list tags
cdef list heads
cdef list labels
cdef readonly unicode raw_text
cdef readonly list words
cdef readonly list ids
cdef readonly list tags
cdef readonly list heads
cdef readonly list labels
cdef int heads_correct(self, TokenC* tokens, bint score_punct=?) except -1

View File

@ -13,7 +13,10 @@ cdef class GoldParse:
self.c_labels = <int*>self.mem.alloc(self.length, sizeof(int))
cdef int heads_correct(self, TokenC* tokens, bint score_punct=False) except -1:
pass
n = 0
for i in range(self.length):
n += (i + tokens[i].head) == self.c_heads[i]
return n
@classmethod
def from_conll(cls, unicode sent_str):
@ -57,7 +60,7 @@ cdef class GoldParse:
tags.append(pos_string)
tokenized = [sent_str.replace('<SEP>', ' ').split(' ')
for sent_str in tok_text.split('<SENT>')]
return cls(raw_text, tokenized, ids, words, tags, heads, labels)
return cls(raw_text, words, ids, tags, heads, labels)
def align_to_tokens(self, tokens, label_ids):
orig_words = list(self.words)
@ -70,9 +73,7 @@ cdef class GoldParse:
for token in tokens:
while annot and token.idx > annot[0][0]:
miss_id, miss_tag, miss_head, miss_label = annot.pop(0)
miss_w = self.words.pop(0)
if not is_punct_label(miss_label):
missed.append(miss_w)
self.loss += 1
if not annot:
self.tags.append(None)
@ -85,15 +86,22 @@ cdef class GoldParse:
self.heads.append(head)
self.labels.append(label)
annot.pop(0)
self.words.pop(0)
elif token.idx < id_:
self.tags.append(None)
self.heads.append(None)
self.labels.append(None)
else:
raise StandardError
self.length = len(tokens)
self.c_heads = <int*>self.mem.alloc(self.length, sizeof(int))
self.c_labels = <int*>self.mem.alloc(self.length, sizeof(int))
self.ids = [token.idx for token in tokens]
mapped_heads = _map_indices_to_tokens(self.ids, self.heads)
for i in range(self.length):
if mapped_heads[i] is None:
self.c_heads[i] = -1
self.c_labels[i] = -1
else:
self.c_heads[i] = mapped_heads[i]
self.c_labels[i] = label_ids[self.labels[i]]
return self.loss
@ -125,38 +133,3 @@ def _parse_line(line):
head_idx = int(pieces[6])
label = pieces[7]
return id_, word, pos, head_idx, label
"""
# TODO
def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
global loss
nlp = Language()
n_corr = 0
pos_corr = 0
n_tokens = 0
total = 0
skipped = 0
loss = 0
with codecs.open(dev_loc, 'r', 'utf8') as file_:
#paragraphs = read_tokenized_gold(file_)
paragraphs = read_docparse_gold(file_)
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
gold_preproc=gold_preproc):
assert len(tokens) == len(labels)
nlp.tagger(tokens)
nlp.parser(tokens)
for i, token in enumerate(tokens):
pos_corr += token.tag_ == tag_strs[i]
n_tokens += 1
if heads[i] is None:
skipped += 1
continue
if is_punct_label(labels[i]):
continue
n_corr += token.head.i == heads[i]
total += 1
print loss, skipped, (loss+skipped + total)
print pos_corr / n_tokens
return float(n_corr) / (total + loss)
"""

View File

@ -6,6 +6,6 @@ from ..tokens cimport Tokens, TokenC
cdef class GreedyParser:
cdef object cfg
cdef readonly object cfg
cdef readonly Model model
cdef TransitionSystem moves
cdef readonly TransitionSystem moves

View File

@ -98,7 +98,6 @@ cdef class GreedyParser:
cdef Pool mem = Pool()
cdef State* state = init_state(mem, tokens.data, tokens.length)
while not is_final(state):
fill_context(context, state)
scores = self.model.score(context)

View File

@ -27,6 +27,7 @@ cdef class TransitionSystem:
cdef readonly dict label_ids
cdef Pool mem
cdef const Transition* c
cdef readonly int n_moves
cdef Transition init_transition(self, int clas, int move, int label) except *

View File

@ -14,14 +14,16 @@ class OracleError(Exception):
cdef class TransitionSystem:
def __init__(self, dict labels_by_action):
self.mem = Pool()
self.n_moves = sum(len(labels) for labels in labels_by_action.items())
self.n_moves = sum(len(labels) for labels in labels_by_action.values())
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
cdef int i = 0
self.label_ids = {}
cdef int label_id
self.label_ids = {'ROOT': 0, 'MISSING': -1}
for action, label_strs in sorted(labels_by_action.items()):
for label_str in sorted(label_strs):
label_str = unicode(label_str)
label_id = self.label_ids.setdefault(label_str, len(self.label_ids))
moves[i] = self.init_transition(i, action, label_id)
moves[i] = self.init_transition(i, int(action), label_id)
i += 1
self.c = moves