From c7e3dfc1dc581a0dbe3cfc7adbd48d8df7c7894d Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 8 Jun 2015 14:49:04 +0200 Subject: [PATCH] * Don't automatically push words when stack is empty, as it messes up beam parsing. Add hash method to beam state. --- spacy/syntax/_state.pyx | 87 +++++++++++++++++++++++-------- spacy/syntax/arc_eager.pyx | 104 ++++++++++++++++++------------------- spacy/syntax/parser.pyx | 31 ++++++++--- 3 files changed, 142 insertions(+), 80 deletions(-) diff --git a/spacy/syntax/_state.pyx b/spacy/syntax/_state.pyx index 3e28a6cd4..3a876df2e 100644 --- a/spacy/syntax/_state.pyx +++ b/spacy/syntax/_state.pyx @@ -61,8 +61,8 @@ cdef int pop_stack(State *s) except -1: assert s.stack_len >= 1 s.stack_len -= 1 s.stack -= 1 - if s.stack_len == 0 and not at_eol(s): - push_stack(s) + #if s.stack_len == 0 and not at_eol(s): + # push_stack(s) cdef int push_stack(State *s) except -1: @@ -114,27 +114,29 @@ cdef bint has_head(const TokenC* t) nogil: cdef const TokenC* get_left(const State* s, const TokenC* head, const int idx) nogil: - cdef uint32_t kids = head.l_kids - if kids == 0: - return NULL - cdef int offset = _nth_significant_bit(kids, idx) - cdef const TokenC* child = head - offset - if child >= s.sent: - return child - else: - return NULL + return _new_get_left(s, head, idx) + #cdef uint32_t kids = head.l_kids + #if kids == 0: + # return NULL + #cdef int offset = _nth_significant_bit(kids, idx) + #cdef const TokenC* child = head - offset + #if child >= s.sent: + # return child + ##else: + # return NULL cdef const TokenC* get_right(const State* s, const TokenC* head, const int idx) nogil: - cdef uint32_t kids = head.r_kids - if kids == 0: - return NULL - cdef int offset = _nth_significant_bit(kids, idx) - cdef const TokenC* child = head + offset - if child < (s.sent + s.sent_len): - return child - else: - return NULL + return _new_get_right(s, head, idx) + #cdef uint32_t kids = head.r_kids + #if kids == 0: + # return NULL + #cdef int offset = _nth_significant_bit(kids, idx) + #cdef const TokenC* child = head + offset + #if child < (s.sent + s.sent_len): + # return child + #else: + # return NULL cdef int count_left_kids(const TokenC* head) nogil: @@ -190,7 +192,7 @@ cdef int copy_state(State* dest, const State* src) except -1: # From https://en.wikipedia.org/wiki/Hamming_weight cdef inline uint32_t _popcount(uint32_t x) nogil: """Find number of non-zero bits.""" - cdef int count = 0 + cdef uint32_t count = 0 while x != 0: x &= x - 1 count += 1 @@ -198,10 +200,51 @@ cdef inline uint32_t _popcount(uint32_t x) nogil: cdef inline uint32_t _nth_significant_bit(uint32_t bits, int n) nogil: - cdef int i + cdef uint32_t i for i in range(32): if bits & (1 << i): n -= 1 if n < 1: return i return 0 + + +cdef const TokenC* _new_get_left(const State* s, const TokenC* target, int idx) nogil: + if idx < 1: + return NULL + cdef const TokenC* ptr = s.sent + while ptr < target: + # If this head is still to the right of us, we can skip to it + # No token that's between this token and this head could be our + # child. + if (ptr.head >= 1) and (ptr + ptr.head) < target: + ptr += ptr.head + + elif ptr + ptr.head == target: + idx -= 1 + if idx == 0: + return ptr + ptr += 1 + else: + ptr += 1 + return NULL + + +cdef const TokenC* _new_get_right(const State* s, const TokenC* target, int idx) nogil: + if idx < 1: + return NULL + cdef const TokenC* ptr = s.sent + (s.sent_len - 1) + while ptr > target: + # If this head is still to the right of us, we can skip to it + # No token that's between this token and this head could be our + # child. + if (ptr.head < 0) and ((ptr + ptr.head) > target): + ptr += ptr.head + elif ptr + ptr.head == target: + idx -= 1 + if idx == 0: + return ptr + ptr -= 1 + else: + ptr -= 1 + return NULL diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 855535f4e..afa05bd9a 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -55,6 +55,8 @@ cdef int push_cost(const State* st, const GoldParseC* gold, int target) except - cdef int cost = 0 cost += head_in_stack(st, target, gold.heads) cost += children_in_stack(st, target, gold.heads) + # If we can Break, we shouldn't push + cost += Break.is_valid(st, -1) and Break.move_cost(st, gold) == 0 return cost @@ -65,15 +67,42 @@ cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1 return cost -cdef int arc_cost(const GoldParseC* gold, int head, int child, int label) except -1: - if gold.heads[child] != head: +cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child) except -1: + if arc_is_gold(gold, head, child): return 0 - elif gold.labels[child] == -1: - return 0 - elif gold.labels[child] == label: - return 0 - else: + elif (child + st.sent[child].head) == gold.heads[child]: return 1 + elif gold.heads[child] >= st.i: + return 1 + else: + return 0 + + + +cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1: + if gold.labels[child] == -1: + return True + elif _is_gold_root(gold, head) and _is_gold_root(gold, child): + return True + elif gold.heads[child] == head: + return True + else: + return False + + +cdef bint label_is_gold(const GoldParseC* gold, int head, int child, int label) except -1: + if gold.labels[child] == -1: + return True + elif label == -1: + return True + elif gold.labels[child] == label: + return True + else: + return False + + +cdef bint _is_gold_root(const GoldParseC* gold, int word) except -1: + return gold.labels[word] == -1 or gold.heads[word] == word cdef class Shift: @@ -96,11 +125,7 @@ cdef class Shift: @staticmethod cdef int move_cost(const State* s, const GoldParseC* gold) except -1: - cdef int cost = push_cost(s, gold, s.i) - # If we can break, and there's no cost to doing so, we should - if Break.is_valid(s, -1) and Break.cost(s, gold, -1) == 0: - cost += 1 - return cost + return push_cost(s, gold, s.i) @staticmethod cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1: @@ -117,7 +142,7 @@ cdef class Reduce: @staticmethod cdef int transition(State* state, int label) except -1: - if NON_MONOTONIC and not has_head(get_s0(state)): + if NON_MONOTONIC and not has_head(get_s0(state)) and state.stack_len >= 2: add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep) pop_stack(state) @@ -139,7 +164,6 @@ cdef class Reduce: return 0 - cdef class LeftArc: @staticmethod cdef bint is_valid(const State* s, int label) except -1: @@ -167,31 +191,14 @@ cdef class LeftArc: cdef int move_cost(const State* s, const GoldParseC* gold) except -1: if not LeftArc.is_valid(s, -1): return 9000 - cdef int cost = 0 - if gold.heads[s.stack[0]] == s.i: - return cost - elif at_eol(s): - # Are we root? - if gold.labels[s.stack[0]] != -1: - # If we're at EOL, prefer to reduce or break over left-arc - if Reduce.is_valid(s, -1) or Break.is_valid(s, -1): - cost += gold.heads[s.stack[0]] != s.stack[0] - return cost - cost += head_in_buffer(s, s.stack[0], gold.heads) - cost += children_in_buffer(s, s.stack[0], gold.heads) - if NON_MONOTONIC and s.stack_len >= 2: - cost += gold.heads[s.stack[0]] == s.stack[-1] - if gold.labels[s.stack[0]] != -1: - cost += gold.heads[s.stack[0]] == s.stack[0] - return cost + elif arc_is_gold(gold, s.i, s.stack[0]): + return 0 + else: + return pop_cost(s, gold, s.stack[0]) + arc_cost(s, gold, s.i, s.stack[0]) @staticmethod cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1: - if label == -1 or gold.labels[s.stack[0]] == -1: - return 0 - if gold.heads[s.stack[0]] == s.i and label != gold.labels[s.stack[0]]: - return 1 - return 0 + return arc_is_gold(gold, s.i, s.stack[0]) and not label_is_gold(gold, s.i, s.stack[0], label) cdef class RightArc: @@ -212,21 +219,14 @@ cdef class RightArc: @staticmethod cdef int move_cost(const State* s, const GoldParseC* gold) except -1: - return push_cost(s, gold, s.i) - (gold.heads[s.i] == s.stack[0]) + if arc_is_gold(gold, s.stack[0], s.i): + return 0 + else: + return push_cost(s, gold, s.i) + arc_cost(s, gold, s.stack[0], s.i) @staticmethod cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1: - return arc_cost(gold, s.stack[0], s.i, label) - #cdef int cost = 0 - #if gold.heads[s.i] == s.stack[0]: - # cost += label != -1 and label != gold.labels[s.i] - # return cost - # This indicates missing head - #if gold.labels[s.i] != -1: - # cost += head_in_buffer(s, s.i, gold.heads) - #cost += children_in_stack(s, s.i, gold.heads) - #cost += head_in_stack(s, s.i, gold.heads) - #return cost + return arc_is_gold(gold, s.stack[0], s.i) and not label_is_gold(gold, s.stack[0], s.i, label) cdef class Break: @@ -237,8 +237,10 @@ cdef class Break: return False elif at_eol(s): return False - #elif NON_MONOTONIC: - # return True + elif s.stack_len < 1: + return False + elif NON_MONOTONIC: + return True else: # In the Break transition paper, they have this constraint that prevents # Break if stack is disconnected. But, if we're doing non-monotonic parsing, @@ -262,8 +264,6 @@ cdef class Break: get_s0(state).dep = label state.stack -= 1 state.stack_len -= 1 - if not at_eol(state): - push_stack(state) @staticmethod cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 639f91c03..47921563b 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -14,7 +14,7 @@ import json from cymem.cymem cimport Pool, Address from murmurhash.mrmr cimport hash64 -from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t +from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t from util import Config @@ -34,7 +34,7 @@ from ..strings cimport StringStore from .arc_eager cimport TransitionSystem, Transition from .transition_system import OracleError -from ._state cimport State, new_state, copy_state, is_final, push_stack +from ._state cimport State, new_state, copy_state, is_final, push_stack, get_left, get_n0 from ..gold cimport GoldParse from . import _parse_features @@ -83,14 +83,14 @@ cdef class Parser: def __call__(self, Tokens tokens): if tokens.length == 0: return 0 - if self.cfg.get('beam_width', 1) <= 1: + if self.cfg.get('beam_width', 1) < 1: self._greedy_parse(tokens) else: self._beam_parse(tokens) def train(self, Tokens tokens, GoldParse gold): self.moves.preprocess_gold(gold) - if self.cfg.beam_width <= 1: + if self.cfg.beam_width < 1: return self._greedy_train(tokens, gold) else: return self._beam_train(tokens, gold) @@ -185,8 +185,7 @@ cdef class Parser: if follow_gold: for j in range(self.moves.n_moves): beam.is_valid[i][j] *= beam.costs[i][j] == 0 - beam.advance(_transition_state, self.moves.c) - state = beam.at(0) + beam.advance(_transition_state, _hash_state, self.moves.c) beam.check_done(_check_final_state, NULL) def _count_feats(self, dict counts, Tokens tokens, list hist, int inc): @@ -222,3 +221,23 @@ cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: cdef int _check_final_state(void* state, void* extra_args) except -1: return is_final(state) + + +cdef hash_t _hash_state(void* _state, void* _) except 0: + state = _state + cdef atom_t[10] rep + + rep[0] = state.stack[0] if state.stack_len >= 1 else 0 + rep[1] = state.stack[-1] if state.stack_len >= 2 else 0 + rep[2] = state.stack[-2] if state.stack_len >= 3 else 0 + rep[3] = state.i + rep[4] = state.sent[state.stack[0]].l_kids if state.stack_len >= 1 else 0 + rep[5] = state.sent[state.stack[0]].r_kids if state.stack_len >= 1 else 0 + rep[6] = state.sent[state.stack[0]].dep if state.stack_len >= 1 else 0 + rep[7] = state.sent[state.stack[-1]].dep if state.stack_len >= 2 else 0 + if get_left(state, get_n0(state), 1) != NULL: + rep[8] = get_left(state, get_n0(state), 1).dep + else: + rep[8] = 0 + rep[9] = state.sent[state.i].l_kids + return hash64(rep, sizeof(atom_t) * 10, 0)