diff --git a/spacy/structs.pxd b/spacy/structs.pxd index 6a15b8951..8b1a8d942 100644 --- a/spacy/structs.pxd +++ b/spacy/structs.pxd @@ -53,6 +53,7 @@ cdef struct Constituent: int start int end int label + bint on_stack cdef struct TokenC: diff --git a/spacy/syntax/_state.pxd b/spacy/syntax/_state.pxd index a1f17b94c..a66140b0b 100644 --- a/spacy/syntax/_state.pxd +++ b/spacy/syntax/_state.pxd @@ -14,6 +14,7 @@ cdef struct State: int sent_len int stack_len int ents_len + int ctnt_len cdef int add_dep(const State *s, const int head, const int child, const int label) except -1 diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 7d3d36347..d24848715 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -1,10 +1,11 @@ from __future__ import unicode_literals from ._state cimport State -from ._state cimport has_head, get_idx, get_s0, get_n0 +from ._state cimport has_head, get_idx, get_s0, get_n0, get_left, get_right from ._state cimport is_final, at_eol, pop_stack, push_stack, add_dep from ._state cimport head_in_buffer, children_in_buffer from ._state cimport head_in_stack, children_in_stack +from ._state cimport count_left_kids from ..structs cimport TokenC @@ -24,15 +25,23 @@ cdef enum: REDUCE LEFT RIGHT + BREAK + + CONSTITUENT + ADJUST + N_MOVES + MOVE_NAMES = [None] * N_MOVES MOVE_NAMES[SHIFT] = 'S' MOVE_NAMES[REDUCE] = 'D' MOVE_NAMES[LEFT] = 'L' MOVE_NAMES[RIGHT] = 'R' MOVE_NAMES[BREAK] = 'B' +MOVE_NAMES[CONSTITUENT] = 'C' +MOVE_NAMES[ADJUST] = 'A' cdef do_func_t[N_MOVES] do_funcs @@ -43,20 +52,29 @@ cdef class ArcEager(TransitionSystem): @classmethod def get_labels(cls, gold_parses): move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {}, - LEFT: {'ROOT': True}, BREAK: {'ROOT': True}} - for raw_text, segmented, (ids, words, tags, heads, labels, iob) in gold_parses: + LEFT: {'ROOT': True}, BREAK: {'ROOT': True}, + CONSTITUENT: {}, ADJUST: {'': True}} + for raw_text, segmented, (ids, words, tags, heads, labels, iob), ctnts in gold_parses: for child, head, label in zip(ids, heads, labels): if label != 'ROOT': if head < child: move_labels[RIGHT][label] = True elif head > child: move_labels[LEFT][label] = True + for start, end, label in ctnts: + move_labels[CONSTITUENT][label] = True return move_labels cdef int preprocess_gold(self, GoldParse gold) except -1: for i in range(gold.length): gold.c_heads[i] = gold.heads[i] gold.c_labels[i] = self.strings[gold.labels[i]] + for end, brackets in gold.brackets.items(): + for start, label_strs in brackets.items(): + gold.c_brackets[start][end] = 1 + for label_str in label_strs: + # Add the encoded label to the set + gold.brackets[end][start].add(self.strings[label_str]) cdef Transition lookup_transition(self, object name) except *: if '-' in name: @@ -104,6 +122,8 @@ cdef class ArcEager(TransitionSystem): is_valid[LEFT] = _can_left(s) is_valid[RIGHT] = _can_right(s) is_valid[BREAK] = _can_break(s) + is_valid[CONSTITUENT] = _can_constituent(s) + is_valid[ADJUST] = _can_adjust(s) cdef Transition best cdef weight_t score = MIN_SCORE cdef int i @@ -162,11 +182,42 @@ cdef int _do_break(const Transition* self, State* state) except -1: push_stack(state) +cdef int _do_constituent(const Transition* self, State* state) except -1: + cdef const TokenC* s0 = get_s0(state) + if state.ctnt.head == get_idx(state, s0): + start = state.ctnt.start + else: + start = get_idx(state, s0) + state.ctnt += 1 + state.ctnt.start = start + state.ctnt.end = s0.r_edge + state.ctnt.head = get_idx(state, s0) + state.ctnt.label = self.label + + +cdef int _do_adjust(const Transition* self, State* state) except -1: + cdef const TokenC* child + cdef const TokenC* s0 = get_s0(state) + cdef int n_left = count_left_kids(s0) + for i in range(1, n_left): + child = get_left(state, s0, i) + assert child is not NULL + if child.l_edge < state.ctnt.start: + state.ctnt.start = child.l_edge + break + else: + msg = ("Error moving bracket --- Move should be invalid if " + "no left edge to move to.") + raise Exception(msg) + + do_funcs[SHIFT] = _do_shift do_funcs[REDUCE] = _do_reduce do_funcs[LEFT] = _do_left do_funcs[RIGHT] = _do_right do_funcs[BREAK] = _do_break +do_funcs[CONSTITUENT] = _do_constituent +do_funcs[ADJUST] = _do_adjust cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1: @@ -243,11 +294,72 @@ cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) exc return cost +cdef int _constituent_cost(const Transition* self, const State* s, GoldParse gold) except -1: + if not _can_constituent(s): + return 9000 + # The gold standard is indexed by end, then by start, then a set of labels + brackets = gold.brackets(get_s0(s).r_edge, {}) + if not brackets: + return 2 # 2 loss for bad bracket, only 1 for good bracket bad label + # Index the current brackets in the state + existing = set() + for i in range(s.ctnt_len): + if ctnt.end == s.r_edge and ctnt.label == self.label: + existing.add(ctnt.start) + cdef int loss = 2 + cdef const TokenC* child + cdef const TokenC* s0 = get_s0(s) + cdef int n_left = count_left_kids(s0) + # Iterate over the possible start positions, and check whether we have a + # (start, end, label) match to the gold tree + for i in range(1, n_left): + child = get_left(s, s0, i) + if child.l_edge in brackets and child.l_edge not in existing: + if self.label in brackets[child.l_edge] + return 0 + else: + loss = 1 # If we see the start position, set loss to 1 + return loss + + +cdef int _adjust_cost(const Transition* self, const State* s, GoldParse gold) except -1: + if not _can_adjust(s): + return 9000 + # The gold standard is indexed by end, then by start, then a set of labels + gold_starts = gold.brackets(get_s0(s).r_edge, {}) + # Case 1: There are 0 brackets ending at this word. + # --> Cost is sunk, but must allow brackets to begin + if not gold_starts: + return 0 + # Is the top bracket correct? + gold_labels = gold_starts.get(s.ctnt.start, set()) + # TODO: Case where we have a unary rule + # TODO: Case where two brackets end on this word, with top bracket starting + # before + + cdef const TokenC* child + cdef const TokenC* s0 = get_s0(s) + cdef int n_left = count_left_kids(s0) + cdef int i + # Iterate over the possible start positions, and check whether we have a + # (start, end, label) match to the gold tree + for i in range(1, n_left): + child = get_left(s, s0, i) + if child.l_edge in brackets: + if self.label in brackets[child.l_edge]: + return 0 + else: + loss = 1 # If we see the start position, set loss to 1 + return loss + + get_cost_funcs[SHIFT] = _shift_cost get_cost_funcs[REDUCE] = _reduce_cost get_cost_funcs[LEFT] = _left_cost get_cost_funcs[RIGHT] = _right_cost get_cost_funcs[BREAK] = _break_cost +get_cost_funcs[CONSTITUENT] = _constituent_cost +get_cost_funcs[ADJUST] = _adjust_cost cdef inline bint _can_shift(const State* s) nogil: @@ -288,3 +400,21 @@ cdef inline bint _can_break(const State* s) nogil: else: seen_headless = True return True + + +cdef inline bint _can_constituent(const State* s) nogil: + return s.stack_len >= 1 + + +cdef inline bint _can_adjust(const State* s) nogil: + # Need a left child to move the bracket to + cdef const TokenC* child + cdef const TokenC* s0 = get_s0(s) + cdef int n_left = count_left_kids(s0) + cdef int i + for i in range(1, n_left): + child = get_left(s, s0, i) + if child.l_edge < s.ctnt.start: + return True + else: + return False diff --git a/spacy/syntax/conll.pxd b/spacy/syntax/conll.pxd index 815920ea6..508c575c0 100644 --- a/spacy/syntax/conll.pxd +++ b/spacy/syntax/conll.pxd @@ -16,10 +16,12 @@ cdef class GoldParse: cdef readonly dict orths cdef readonly list ner cdef readonly list ents + cdef readonly dict brackets cdef int* c_tags cdef int* c_heads cdef int* c_labels + cdef int** c_brackets cdef Transition* c_ner cdef int heads_correct(self, TokenC* tokens, bint score_punct=?) except -1 diff --git a/spacy/syntax/conll.pyx b/spacy/syntax/conll.pyx index ff3af58c3..c4afeb02d 100644 --- a/spacy/syntax/conll.pyx +++ b/spacy/syntax/conll.pyx @@ -30,7 +30,7 @@ def read_json_file(loc): paragraphs.append((paragraph['raw'], tokenized, (ids, words, tags, heads, labels, _iob_to_biluo(iob_ents)), - brackets)) + paragraph.get('brackets', []))) return paragraphs @@ -145,7 +145,7 @@ def _parse_line(line): cdef class GoldParse: - def __init__(self, tokens, annot_tuples): + def __init__(self, tokens, annot_tuples, brackets=(,)): self.mem = Pool() self.loss = 0 self.length = len(tokens) @@ -155,6 +155,9 @@ cdef class GoldParse: self.c_heads = self.mem.alloc(len(tokens), sizeof(int)) self.c_labels = self.mem.alloc(len(tokens), sizeof(int)) self.c_ner = self.mem.alloc(len(tokens), sizeof(Transition)) + self.c_brackets = self.mem.alloc(len(tokens), sizeof(int*)) + for i in range(len(tokens)): + self.c_brackets[i] = self.mem.alloc(len(tokens), sizeof(int)) self.tags = [None] * len(tokens) self.heads = [-1] * len(tokens) @@ -199,6 +202,14 @@ cdef class GoldParse: self.ner[i] = 'I-%s' % label self.ner[end-1] = 'L-%s' % label + self.brackets = {} + for (start_idx, end_idx, label_str) in brackets: + if start_idx in idx_map and end_idx in idx_map: + start = idx_map[start_idx] + end = idx_map[end_idx] + self.brackets.setdefault(end, {}).setdefault(start, set()) + self.brackets[end][start].add(label) + def __len__(self): return self.length diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 09495ae92..36acce3de 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -95,6 +95,7 @@ cdef class GreedyParser: return 0 def train(self, Tokens tokens, GoldParse gold): + py_words = [w.orth_ for w in tokens] self.moves.preprocess_gold(gold) cdef Pool mem = Pool() cdef State* state = new_state(mem, tokens.data, tokens.length)