* Tmp commit

This commit is contained in:
Matthew Honnibal 2015-05-11 16:12:03 +02:00
parent 9568ebed08
commit 03a6626545
6 changed files with 151 additions and 5 deletions

View File

@ -53,6 +53,7 @@ cdef struct Constituent:
int start
int end
int label
bint on_stack
cdef struct TokenC:

View File

@ -14,6 +14,7 @@ cdef struct State:
int sent_len
int stack_len
int ents_len
int ctnt_len
cdef int add_dep(const State *s, const int head, const int child, const int label) except -1

View File

@ -1,10 +1,11 @@
from __future__ import unicode_literals
from ._state cimport State
from ._state cimport has_head, get_idx, get_s0, get_n0
from ._state cimport has_head, get_idx, get_s0, get_n0, get_left, get_right
from ._state cimport is_final, at_eol, pop_stack, push_stack, add_dep
from ._state cimport head_in_buffer, children_in_buffer
from ._state cimport head_in_stack, children_in_stack
from ._state cimport count_left_kids
from ..structs cimport TokenC
@ -24,15 +25,23 @@ cdef enum:
REDUCE
LEFT
RIGHT
BREAK
CONSTITUENT
ADJUST
N_MOVES
# One-letter display name for every move, stored at the index given by the
# move's enum value; N_MOVES (the last enum member) sizes the list.
MOVE_NAMES = [None] * N_MOVES
MOVE_NAMES[SHIFT] = 'S'
MOVE_NAMES[REDUCE] = 'D'
MOVE_NAMES[LEFT] = 'L'
MOVE_NAMES[RIGHT] = 'R'
MOVE_NAMES[BREAK] = 'B'
MOVE_NAMES[CONSTITUENT] = 'C'
MOVE_NAMES[ADJUST] = 'A'
# Table with one slot per move holding the function that applies that move;
# the slots are assigned further down, once the _do_* functions exist.
cdef do_func_t[N_MOVES] do_funcs
@ -43,20 +52,29 @@ cdef class ArcEager(TransitionSystem):
@classmethod
def get_labels(cls, gold_parses):
# Collect, for every move type, the labels that occur in the gold data.
# The result maps move id -> {label: True}; the inner dicts act as sets.
# NOTE(review): this listing is a rendered diff without +/- markers, so the
# old dict tail and old 3-tuple loop header below appear next to their
# replacements (the lines adding CONSTITUENT/ADJUST and unpacking `ctnts`).
# Only one pair can be present in the real file -- confirm before editing.
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {},
LEFT: {'ROOT': True}, BREAK: {'ROOT': True}}
for raw_text, segmented, (ids, words, tags, heads, labels, iob) in gold_parses:
LEFT: {'ROOT': True}, BREAK: {'ROOT': True},
CONSTITUENT: {}, ADJUST: {'': True}}
for raw_text, segmented, (ids, words, tags, heads, labels, iob), ctnts in gold_parses:
for child, head, label in zip(ids, heads, labels):
# 'ROOT' is pre-registered for LEFT and BREAK above, so it is skipped here.
if label != 'ROOT':
# A head with a smaller id than its child registers the label for RIGHT;
# a head with a larger id registers it for LEFT.
if head < child:
move_labels[RIGHT][label] = True
elif head > child:
move_labels[LEFT][label] = True
# Each bracket is a (start, end, label) triple; only the label is recorded.
for start, end, label in ctnts:
move_labels[CONSTITUENT][label] = True
return move_labels
cdef int preprocess_gold(self, GoldParse gold) except -1:
    # Fill the C-level gold fields from the Python-level annotations:
    #   * gold.c_heads[i]  <- gold.heads[i]
    #   * gold.c_labels[i] <- gold.labels[i] encoded through self.strings
    #   * gold.c_brackets[start][end] <- 1 for every gold bracket
    # and add the encoded form of every bracket label to that bracket's
    # label set, next to the original label strings.
    for i in range(gold.length):
        gold.c_heads[i] = gold.heads[i]
        gold.c_labels[i] = self.strings[gold.labels[i]]
    # gold.brackets is indexed by end, then by start, then a set of labels.
    for end, brackets in gold.brackets.items():
        for start, label_strs in brackets.items():
            gold.c_brackets[start][end] = 1
            # label_strs *is* gold.brackets[end][start]. Adding to a set while
            # iterating over it raises RuntimeError, so walk a snapshot.
            for label_str in list(label_strs):
                # Add the encoded label to the set
                label_strs.add(self.strings[label_str])
cdef Transition lookup_transition(self, object name) except *:
if '-' in name:
@ -104,6 +122,8 @@ cdef class ArcEager(TransitionSystem):
is_valid[LEFT] = _can_left(s)
is_valid[RIGHT] = _can_right(s)
is_valid[BREAK] = _can_break(s)
is_valid[CONSTITUENT] = _can_constituent(s)
is_valid[ADJUST] = _can_adjust(s)
cdef Transition best
cdef weight_t score = MIN_SCORE
cdef int i
@ -162,11 +182,42 @@ cdef int _do_break(const Transition* self, State* state) except -1:
push_stack(state)
cdef int _do_constituent(const Transition* self, State* state) except -1:
    # Apply CONSTITUENT: record a new bracket for the word on top of the
    # stack (s0), labelled with this transition's label and ending at s0's
    # right edge.
    cdef const TokenC* s0 = get_s0(state)
    # If the current bracket is already headed by s0, the new bracket keeps
    # that bracket's start; otherwise it starts at s0 itself.
    if state.ctnt.head != get_idx(state, s0):
        start = get_idx(state, s0)
    else:
        start = state.ctnt.start
    # Step to the next constituent slot and fill it in.
    # NOTE(review): nothing here touches state.ctnt_len -- confirm that the
    # count is maintained elsewhere.
    state.ctnt += 1
    state.ctnt.label = self.label
    state.ctnt.head = get_idx(state, s0)
    state.ctnt.end = s0.r_edge
    state.ctnt.start = start
cdef int _do_adjust(const Transition* self, State* state) except -1:
    # Apply ADJUST: move the start of the current bracket leftwards, to the
    # left edge of the first left child of s0 (in get_left order) whose left
    # edge lies before the bracket's current start.
    #
    # Raises Exception when no such child exists; the validity check is
    # expected to have ruled the move out in that case.
    cdef const TokenC* s0 = get_s0(state)
    cdef const TokenC* child
    cdef int n_left = count_left_kids(s0)
    cdef bint moved = False
    for i in range(1, n_left):
        child = get_left(state, s0, i)
        assert child is not NULL
        if child.l_edge < state.ctnt.start:
            state.ctnt.start = child.l_edge
            moved = True
            break
    if not moved:
        msg = ("Error moving bracket --- Move should be invalid if "
               "no left edge to move to.")
        raise Exception(msg)
# Register, per move id, the function that applies that move to a State.
# The _do_* functions must already be defined at this point.
do_funcs[SHIFT] = _do_shift
do_funcs[REDUCE] = _do_reduce
do_funcs[LEFT] = _do_left
do_funcs[RIGHT] = _do_right
do_funcs[BREAK] = _do_break
do_funcs[CONSTITUENT] = _do_constituent
do_funcs[ADJUST] = _do_adjust
cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1:
@ -243,11 +294,72 @@ cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) exc
return cost
cdef int _constituent_cost(const Transition* self, const State* s, GoldParse gold) except -1:
    # Cost of applying CONSTITUENT with this transition's label in state s.
    #
    # Returns 9000 when the move is invalid, 0 when some candidate start
    # yields a (start, end, label) bracket found in the gold tree, 1 when a
    # gold bracket has a matching span but a different label, and 2 otherwise.
    cdef const TokenC* child
    cdef const TokenC* s0
    cdef int n_left
    cdef int loss = 2
    cdef int i
    if not _can_constituent(s):
        return 9000
    # Only fetch s0 once the move is known to be valid (non-empty stack).
    s0 = get_s0(s)
    # The gold standard is indexed by end, then by start, then a set of labels.
    # gold.brackets is a dict, so it has to be queried with .get(), not called.
    brackets = gold.brackets.get(s0.r_edge, {})
    if not brackets:
        return 2 # 2 loss for bad bracket, only 1 for good bracket bad label
    # Index the current brackets in the state: starts of brackets that already
    # end at s0's right edge and carry this label.
    existing = set()
    for i in range(s.ctnt_len):
        # NOTE(review): assumes s.ctnt points at the most recently added
        # bracket and earlier ones sit directly before it in memory (the
        # CONSTITUENT move advances the pointer by one) -- confirm against
        # the State layout and how ctnt_len is maintained.
        if s.ctnt[-i].end == s0.r_edge and s.ctnt[-i].label == self.label:
            existing.add(s.ctnt[-i].start)
    n_left = count_left_kids(s0)
    # Iterate over the possible start positions, and check whether we have a
    # (start, end, label) match to the gold tree
    # NOTE(review): range(1, n_left) never visits index n_left -- confirm
    # against get_left's indexing convention.
    for i in range(1, n_left):
        child = get_left(s, s0, i)
        if child.l_edge in brackets and child.l_edge not in existing:
            if self.label in brackets[child.l_edge]:
                return 0
            else:
                loss = 1 # If we see the start position, set loss to 1
    return loss
cdef int _adjust_cost(const Transition* self, const State* s, GoldParse gold) except -1:
    # Cost of applying ADJUST in state s.
    #
    # Returns 9000 when the move is invalid and 0 when no gold bracket ends
    # at s0's right edge. Otherwise it scans the candidate start positions:
    # 0 on a (start, end, label) match with the gold tree, 1 when only the
    # span matches, else the initial loss.
    cdef const TokenC* child
    cdef const TokenC* s0
    cdef int n_left
    cdef int i
    # NOTE(review): the original returned `loss` without initialising it.
    # 2 mirrors the starting loss used by the CONSTITUENT cost -- confirm
    # that this is the intended penalty.
    cdef int loss = 2
    if not _can_adjust(s):
        return 9000
    s0 = get_s0(s)
    # The gold standard is indexed by end, then by start, then a set of labels.
    # gold.brackets is a dict, so it has to be queried with .get(), not called.
    gold_starts = gold.brackets.get(s0.r_edge, {})
    # Case 1: There are 0 brackets ending at this word.
    # --> Cost is sunk, but must allow brackets to begin
    if not gold_starts:
        return 0
    # Is the top bracket correct? (Looked up but not yet used by the cost.)
    gold_labels = gold_starts.get(s.ctnt.start, set())
    # TODO: Case where we have a unary rule
    # TODO: Case where two brackets end on this word, with top bracket starting
    # before
    n_left = count_left_kids(s0)
    # Iterate over the possible start positions, and check whether we have a
    # (start, end, label) match to the gold tree
    for i in range(1, n_left):
        child = get_left(s, s0, i)
        if child.l_edge in gold_starts:
            if self.label in gold_starts[child.l_edge]:
                return 0
            else:
                loss = 1 # If we see the start position, set loss to 1
    return loss
# Register, per move id, the function that scores that move against the gold
# parse for a given State. The _*_cost functions must already be defined.
get_cost_funcs[SHIFT] = _shift_cost
get_cost_funcs[REDUCE] = _reduce_cost
get_cost_funcs[LEFT] = _left_cost
get_cost_funcs[RIGHT] = _right_cost
get_cost_funcs[BREAK] = _break_cost
get_cost_funcs[CONSTITUENT] = _constituent_cost
get_cost_funcs[ADJUST] = _adjust_cost
cdef inline bint _can_shift(const State* s) nogil:
@ -288,3 +400,21 @@ cdef inline bint _can_break(const State* s) nogil:
else:
seen_headless = True
return True
cdef inline bint _can_constituent(const State* s) nogil:
    # CONSTITUENT is valid whenever the stack holds at least one word.
    return s.stack_len > 0
cdef inline bint _can_adjust(const State* s) nogil:
    # ADJUST is valid only if the bracket has somewhere to move to: some left
    # child of s0 must have a left edge before the current bracket's start.
    cdef const TokenC* top = get_s0(s)
    cdef const TokenC* kid
    cdef int n_kids = count_left_kids(top)
    cdef int k
    for k in range(1, n_kids):
        kid = get_left(s, top, k)
        if kid.l_edge < s.ctnt.start:
            return True
    # No child qualified (or there were none to inspect).
    return False

View File

@ -16,10 +16,12 @@ cdef class GoldParse:
cdef readonly dict orths
cdef readonly list ner
cdef readonly list ents
cdef readonly dict brackets
cdef int* c_tags
cdef int* c_heads
cdef int* c_labels
cdef int** c_brackets
cdef Transition* c_ner
cdef int heads_correct(self, TokenC* tokens, bint score_punct=?) except -1

View File

@ -30,7 +30,7 @@ def read_json_file(loc):
paragraphs.append((paragraph['raw'],
tokenized,
(ids, words, tags, heads, labels, _iob_to_biluo(iob_ents)),
brackets))
paragraph.get('brackets', [])))
return paragraphs
@ -145,7 +145,7 @@ def _parse_line(line):
cdef class GoldParse:
def __init__(self, tokens, annot_tuples):
def __init__(self, tokens, annot_tuples, brackets=()):
    # `brackets` is an optional iterable of (start_idx, end_idx, label)
    # triples; it defaults to an empty tuple. The previous default `(,)` is
    # not valid syntax -- the empty tuple literal is `()`.
    self.mem = Pool()
    self.loss = 0
    self.length = len(tokens)
@ -155,6 +155,9 @@ cdef class GoldParse:
self.c_heads = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.c_labels = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.c_ner = <Transition*>self.mem.alloc(len(tokens), sizeof(Transition))
self.c_brackets = <int**>self.mem.alloc(len(tokens), sizeof(int*))
for i in range(len(tokens)):
self.c_brackets[i] = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.tags = [None] * len(tokens)
self.heads = [-1] * len(tokens)
@ -199,6 +202,14 @@ cdef class GoldParse:
self.ner[i] = 'I-%s' % label
self.ner[end-1] = 'L-%s' % label
# Index the gold brackets by end, then by start, then a set of labels.
self.brackets = {}
for (start_idx, end_idx, label_str) in brackets:
    # Brackets whose boundaries are missing from idx_map are dropped.
    if start_idx in idx_map and end_idx in idx_map:
        start = idx_map[start_idx]
        end = idx_map[end_idx]
        # Store this bracket's own label. The earlier code added `label`, a
        # stale variable left over from the NER loop above, instead of the
        # loop variable `label_str`.
        self.brackets.setdefault(end, {}).setdefault(start, set()).add(label_str)
def __len__(self):
return self.length

View File

@ -95,6 +95,7 @@ cdef class GreedyParser:
return 0
def train(self, Tokens tokens, GoldParse gold):
py_words = [w.orth_ for w in tokens]
self.moves.preprocess_gold(gold)
cdef Pool mem = Pool()
cdef State* state = new_state(mem, tokens.data, tokens.length)