* Don't automatically push words when stack is empty, as it messes up beam parsing. Add hash method to beam state.

This commit is contained in:
Matthew Honnibal 2015-06-08 14:49:04 +02:00
parent d51a86478e
commit c7e3dfc1dc
3 changed files with 142 additions and 80 deletions

View File

@ -61,8 +61,8 @@ cdef int pop_stack(State *s) except -1:
assert s.stack_len >= 1 assert s.stack_len >= 1
s.stack_len -= 1 s.stack_len -= 1
s.stack -= 1 s.stack -= 1
if s.stack_len == 0 and not at_eol(s): #if s.stack_len == 0 and not at_eol(s):
push_stack(s) # push_stack(s)
cdef int push_stack(State *s) except -1: cdef int push_stack(State *s) except -1:
@ -114,27 +114,29 @@ cdef bint has_head(const TokenC* t) nogil:
cdef const TokenC* get_left(const State* s, const TokenC* head, const int idx) nogil: cdef const TokenC* get_left(const State* s, const TokenC* head, const int idx) nogil:
cdef uint32_t kids = head.l_kids return _new_get_left(s, head, idx)
if kids == 0: #cdef uint32_t kids = head.l_kids
return NULL #if kids == 0:
cdef int offset = _nth_significant_bit(kids, idx) # return NULL
cdef const TokenC* child = head - offset #cdef int offset = _nth_significant_bit(kids, idx)
if child >= s.sent: #cdef const TokenC* child = head - offset
return child #if child >= s.sent:
else: # return child
return NULL ##else:
# return NULL
cdef const TokenC* get_right(const State* s, const TokenC* head, const int idx) nogil: cdef const TokenC* get_right(const State* s, const TokenC* head, const int idx) nogil:
cdef uint32_t kids = head.r_kids return _new_get_right(s, head, idx)
if kids == 0: #cdef uint32_t kids = head.r_kids
return NULL #if kids == 0:
cdef int offset = _nth_significant_bit(kids, idx) # return NULL
cdef const TokenC* child = head + offset #cdef int offset = _nth_significant_bit(kids, idx)
if child < (s.sent + s.sent_len): #cdef const TokenC* child = head + offset
return child #if child < (s.sent + s.sent_len):
else: # return child
return NULL #else:
# return NULL
cdef int count_left_kids(const TokenC* head) nogil: cdef int count_left_kids(const TokenC* head) nogil:
@ -190,7 +192,7 @@ cdef int copy_state(State* dest, const State* src) except -1:
# From https://en.wikipedia.org/wiki/Hamming_weight # From https://en.wikipedia.org/wiki/Hamming_weight
cdef inline uint32_t _popcount(uint32_t x) nogil: cdef inline uint32_t _popcount(uint32_t x) nogil:
"""Find number of non-zero bits.""" """Find number of non-zero bits."""
cdef int count = 0 cdef uint32_t count = 0
while x != 0: while x != 0:
x &= x - 1 x &= x - 1
count += 1 count += 1
@ -198,10 +200,51 @@ cdef inline uint32_t _popcount(uint32_t x) nogil:
cdef inline uint32_t _nth_significant_bit(uint32_t bits, int n) nogil: cdef inline uint32_t _nth_significant_bit(uint32_t bits, int n) nogil:
cdef int i cdef uint32_t i
for i in range(32): for i in range(32):
if bits & (1 << i): if bits & (1 << i):
n -= 1 n -= 1
if n < 1: if n < 1:
return i return i
return 0 return 0
cdef const TokenC* _new_get_left(const State* s, const TokenC* target, int idx) nogil:
    """Find the idx-th left dependent of `target` by scanning the sentence.

    The scan starts at s.sent and moves rightwards until it reaches `target`.
    A token counts as a dependent when its relative `head` offset lands
    exactly on `target`; dependents are numbered from 1 in scan order.

    Returns NULL when idx < 1 or when fewer than idx dependents are found.
    """
    if idx < 1:
        return NULL
    cdef const TokenC* cursor = s.sent
    cdef const TokenC* governor
    while cursor < target:
        governor = cursor + cursor.head
        if cursor.head >= 1 and governor < target:
            # The current word attaches to a later word that still precedes
            # target, so jump straight to that word. The words in between
            # are skipped; per the original author's note none of them can
            # be a child of target.
            cursor = governor
            continue
        if governor == target:
            idx -= 1
            if idx == 0:
                return cursor
        cursor += 1
    return NULL
cdef const TokenC* _new_get_right(const State* s, const TokenC* target, int idx) nogil:
    """Find the idx-th right dependent of `target` by scanning the sentence.

    The scan starts at the last token (s.sent + s.sent_len - 1) and moves
    leftwards until it reaches `target`. A token counts as a dependent when
    its relative `head` offset lands exactly on `target`; dependents are
    numbered from 1 in scan order, i.e. starting from the right-most one.

    Returns NULL when idx < 1 or when fewer than idx dependents are found.
    """
    if idx < 1:
        return NULL
    cdef const TokenC* cursor = s.sent + (s.sent_len - 1)
    cdef const TokenC* governor
    while cursor > target:
        governor = cursor + cursor.head
        if cursor.head < 0 and governor > target:
            # The current word attaches to an earlier word that is still to
            # the right of target, so jump straight to that word. The words
            # in between are skipped; per the original author's note none of
            # them can be a child of target.
            cursor = governor
            continue
        if governor == target:
            idx -= 1
            if idx == 0:
                return cursor
        cursor -= 1
    return NULL

View File

@ -55,6 +55,8 @@ cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -
cdef int cost = 0 cdef int cost = 0
cost += head_in_stack(st, target, gold.heads) cost += head_in_stack(st, target, gold.heads)
cost += children_in_stack(st, target, gold.heads) cost += children_in_stack(st, target, gold.heads)
# If we can Break, we shouldn't push
cost += Break.is_valid(st, -1) and Break.move_cost(st, gold) == 0
return cost return cost
@ -65,15 +67,42 @@ cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1
return cost return cost
cdef int arc_cost(const GoldParseC* gold, int head, int child, int label) except -1: cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child) except -1:
if gold.heads[child] != head: if arc_is_gold(gold, head, child):
return 0 return 0
elif gold.labels[child] == -1: elif (child + st.sent[child].head) == gold.heads[child]:
return 0
elif gold.labels[child] == label:
return 0
else:
return 1 return 1
elif gold.heads[child] >= st.i:
return 1
else:
return 0
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
    """Tell whether an arc from `head` to `child` is accepted by `gold`.

    The arc is accepted when any of the following holds, checked in order:
      * gold.labels[child] is -1;
      * both `head` and `child` satisfy _is_gold_root;
      * gold.heads[child] equals `head`.
    """
    return (
        gold.labels[child] == -1
        or (_is_gold_root(gold, head) and _is_gold_root(gold, child))
        or gold.heads[child] == head
    )
cdef bint label_is_gold(const GoldParseC* gold, int head, int child, int label) except -1:
    """Tell whether `label` is an acceptable label for `child` under `gold`.

    Accepted when gold.labels[child] is -1, when `label` itself is -1, or
    when `label` equals gold.labels[child]. The `head` argument is not
    consulted by this function.
    """
    return (
        gold.labels[child] == -1
        or label == -1
        or gold.labels[child] == label
    )
cdef bint _is_gold_root(const GoldParseC* gold, int word) except -1:
    """True when gold.labels[word] is -1 or `word` is recorded as its own head."""
    if gold.labels[word] == -1:
        return True
    return gold.heads[word] == word
cdef class Shift: cdef class Shift:
@ -96,11 +125,7 @@ cdef class Shift:
@staticmethod @staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1: cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
cdef int cost = push_cost(s, gold, s.i) return push_cost(s, gold, s.i)
# If we can break, and there's no cost to doing so, we should
if Break.is_valid(s, -1) and Break.cost(s, gold, -1) == 0:
cost += 1
return cost
@staticmethod @staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
@ -117,7 +142,7 @@ cdef class Reduce:
@staticmethod @staticmethod
cdef int transition(State* state, int label) except -1: cdef int transition(State* state, int label) except -1:
if NON_MONOTONIC and not has_head(get_s0(state)): if NON_MONOTONIC and not has_head(get_s0(state)) and state.stack_len >= 2:
add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep) add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep)
pop_stack(state) pop_stack(state)
@ -139,7 +164,6 @@ cdef class Reduce:
return 0 return 0
cdef class LeftArc: cdef class LeftArc:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(const State* s, int label) except -1:
@ -167,31 +191,14 @@ cdef class LeftArc:
cdef int move_cost(const State* s, const GoldParseC* gold) except -1: cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
if not LeftArc.is_valid(s, -1): if not LeftArc.is_valid(s, -1):
return 9000 return 9000
cdef int cost = 0 elif arc_is_gold(gold, s.i, s.stack[0]):
if gold.heads[s.stack[0]] == s.i: return 0
return cost else:
elif at_eol(s): return pop_cost(s, gold, s.stack[0]) + arc_cost(s, gold, s.i, s.stack[0])
# Are we root?
if gold.labels[s.stack[0]] != -1:
# If we're at EOL, prefer to reduce or break over left-arc
if Reduce.is_valid(s, -1) or Break.is_valid(s, -1):
cost += gold.heads[s.stack[0]] != s.stack[0]
return cost
cost += head_in_buffer(s, s.stack[0], gold.heads)
cost += children_in_buffer(s, s.stack[0], gold.heads)
if NON_MONOTONIC and s.stack_len >= 2:
cost += gold.heads[s.stack[0]] == s.stack[-1]
if gold.labels[s.stack[0]] != -1:
cost += gold.heads[s.stack[0]] == s.stack[0]
return cost
@staticmethod @staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
if label == -1 or gold.labels[s.stack[0]] == -1: return arc_is_gold(gold, s.i, s.stack[0]) and not label_is_gold(gold, s.i, s.stack[0], label)
return 0
if gold.heads[s.stack[0]] == s.i and label != gold.labels[s.stack[0]]:
return 1
return 0
cdef class RightArc: cdef class RightArc:
@ -212,21 +219,14 @@ cdef class RightArc:
@staticmethod @staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1: cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
return push_cost(s, gold, s.i) - (gold.heads[s.i] == s.stack[0]) if arc_is_gold(gold, s.stack[0], s.i):
return 0
else:
return push_cost(s, gold, s.i) + arc_cost(s, gold, s.stack[0], s.i)
@staticmethod @staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
return arc_cost(gold, s.stack[0], s.i, label) return arc_is_gold(gold, s.stack[0], s.i) and not label_is_gold(gold, s.stack[0], s.i, label)
#cdef int cost = 0
#if gold.heads[s.i] == s.stack[0]:
# cost += label != -1 and label != gold.labels[s.i]
# return cost
# This indicates missing head
#if gold.labels[s.i] != -1:
# cost += head_in_buffer(s, s.i, gold.heads)
#cost += children_in_stack(s, s.i, gold.heads)
#cost += head_in_stack(s, s.i, gold.heads)
#return cost
cdef class Break: cdef class Break:
@ -237,8 +237,10 @@ cdef class Break:
return False return False
elif at_eol(s): elif at_eol(s):
return False return False
#elif NON_MONOTONIC: elif s.stack_len < 1:
# return True return False
elif NON_MONOTONIC:
return True
else: else:
# In the Break transition paper, they have this constraint that prevents # In the Break transition paper, they have this constraint that prevents
# Break if stack is disconnected. But, if we're doing non-monotonic parsing, # Break if stack is disconnected. But, if we're doing non-monotonic parsing,
@ -262,8 +264,6 @@ cdef class Break:
get_s0(state).dep = label get_s0(state).dep = label
state.stack -= 1 state.stack -= 1
state.stack_len -= 1 state.stack_len -= 1
if not at_eol(state):
push_stack(state)
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:

View File

@ -14,7 +14,7 @@ import json
from cymem.cymem cimport Pool, Address from cymem.cymem cimport Pool, Address
from murmurhash.mrmr cimport hash64 from murmurhash.mrmr cimport hash64
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
from util import Config from util import Config
@ -34,7 +34,7 @@ from ..strings cimport StringStore
from .arc_eager cimport TransitionSystem, Transition from .arc_eager cimport TransitionSystem, Transition
from .transition_system import OracleError from .transition_system import OracleError
from ._state cimport State, new_state, copy_state, is_final, push_stack from ._state cimport State, new_state, copy_state, is_final, push_stack, get_left, get_n0
from ..gold cimport GoldParse from ..gold cimport GoldParse
from . import _parse_features from . import _parse_features
@ -83,14 +83,14 @@ cdef class Parser:
def __call__(self, Tokens tokens): def __call__(self, Tokens tokens):
if tokens.length == 0: if tokens.length == 0:
return 0 return 0
if self.cfg.get('beam_width', 1) <= 1: if self.cfg.get('beam_width', 1) < 1:
self._greedy_parse(tokens) self._greedy_parse(tokens)
else: else:
self._beam_parse(tokens) self._beam_parse(tokens)
def train(self, Tokens tokens, GoldParse gold): def train(self, Tokens tokens, GoldParse gold):
self.moves.preprocess_gold(gold) self.moves.preprocess_gold(gold)
if self.cfg.beam_width <= 1: if self.cfg.beam_width < 1:
return self._greedy_train(tokens, gold) return self._greedy_train(tokens, gold)
else: else:
return self._beam_train(tokens, gold) return self._beam_train(tokens, gold)
@ -185,8 +185,7 @@ cdef class Parser:
if follow_gold: if follow_gold:
for j in range(self.moves.n_moves): for j in range(self.moves.n_moves):
beam.is_valid[i][j] *= beam.costs[i][j] == 0 beam.is_valid[i][j] *= beam.costs[i][j] == 0
beam.advance(_transition_state, <void*>self.moves.c) beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
state = <State*>beam.at(0)
beam.check_done(_check_final_state, NULL) beam.check_done(_check_final_state, NULL)
def _count_feats(self, dict counts, Tokens tokens, list hist, int inc): def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):
@ -222,3 +221,23 @@ cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
cdef int _check_final_state(void* state, void* extra_args) except -1: cdef int _check_final_state(void* state, void* extra_args) except -1:
return is_final(<State*>state) return is_final(<State*>state)
cdef hash_t _hash_state(void* _state, void* _) except 0:
    """Hash a parser state from a fixed 10-slot summary of it.

    `_state` is cast to a const State*; the second argument is ignored.
    Slots that refer to a stack position the state does not have are set
    to 0. The summary is hashed with hash64 using seed 0.
    """
    state = <const State*>_state
    cdef atom_t[10] rep
    # Slots 0, 4, 5, 6: top-of-stack index and that token's l_kids, r_kids, dep.
    if state.stack_len >= 1:
        rep[0] = state.stack[0]
        rep[4] = state.sent[state.stack[0]].l_kids
        rep[5] = state.sent[state.stack[0]].r_kids
        rep[6] = state.sent[state.stack[0]].dep
    else:
        rep[0] = 0
        rep[4] = 0
        rep[5] = 0
        rep[6] = 0
    # Slots 1, 7: second stack entry and its dep.
    if state.stack_len >= 2:
        rep[1] = state.stack[-1]
        rep[7] = state.sent[state.stack[-1]].dep
    else:
        rep[1] = 0
        rep[7] = 0
    # Slot 2: third stack entry.
    if state.stack_len >= 3:
        rep[2] = state.stack[-2]
    else:
        rep[2] = 0
    # Slots 3, 9: buffer position and the l_kids of the token at that position.
    rep[3] = state.i
    rep[9] = state.sent[state.i].l_kids
    # Slot 8: dep of the first left child of n0, when there is one.
    if get_left(state, get_n0(state), 1) == NULL:
        rep[8] = 0
    else:
        rep[8] = get_left(state, get_n0(state), 1).dep
    return hash64(rep, sizeof(atom_t) * 10, 0)