* Don't automatically push words when stack is empty, as it messes up beam parsing. Add hash method to beam state.

This commit is contained in:
Matthew Honnibal 2015-06-08 14:49:04 +02:00
parent d51a86478e
commit c7e3dfc1dc
3 changed files with 142 additions and 80 deletions

View File

@ -61,8 +61,8 @@ cdef int pop_stack(State *s) except -1:
assert s.stack_len >= 1
s.stack_len -= 1
s.stack -= 1
if s.stack_len == 0 and not at_eol(s):
push_stack(s)
#if s.stack_len == 0 and not at_eol(s):
# push_stack(s)
cdef int push_stack(State *s) except -1:
@ -114,27 +114,29 @@ cdef bint has_head(const TokenC* t) nogil:
cdef const TokenC* get_left(const State* s, const TokenC* head, const int idx) nogil:
cdef uint32_t kids = head.l_kids
if kids == 0:
return NULL
cdef int offset = _nth_significant_bit(kids, idx)
cdef const TokenC* child = head - offset
if child >= s.sent:
return child
else:
return NULL
return _new_get_left(s, head, idx)
#cdef uint32_t kids = head.l_kids
#if kids == 0:
# return NULL
#cdef int offset = _nth_significant_bit(kids, idx)
#cdef const TokenC* child = head - offset
#if child >= s.sent:
# return child
##else:
# return NULL
cdef const TokenC* get_right(const State* s, const TokenC* head, const int idx) nogil:
cdef uint32_t kids = head.r_kids
if kids == 0:
return NULL
cdef int offset = _nth_significant_bit(kids, idx)
cdef const TokenC* child = head + offset
if child < (s.sent + s.sent_len):
return child
else:
return NULL
return _new_get_right(s, head, idx)
#cdef uint32_t kids = head.r_kids
#if kids == 0:
# return NULL
#cdef int offset = _nth_significant_bit(kids, idx)
#cdef const TokenC* child = head + offset
#if child < (s.sent + s.sent_len):
# return child
#else:
# return NULL
cdef int count_left_kids(const TokenC* head) nogil:
@ -190,7 +192,7 @@ cdef int copy_state(State* dest, const State* src) except -1:
# From https://en.wikipedia.org/wiki/Hamming_weight
cdef inline uint32_t _popcount(uint32_t x) nogil:
"""Find number of non-zero bits."""
cdef int count = 0
cdef uint32_t count = 0
while x != 0:
x &= x - 1
count += 1
@ -198,10 +200,51 @@ cdef inline uint32_t _popcount(uint32_t x) nogil:
cdef inline uint32_t _nth_significant_bit(uint32_t bits, int n) nogil:
    """Return the zero-based position of the n-th set bit of `bits`.

    Bits are scanned from the least significant end and `n` is 1-based;
    any n <= 1 selects the lowest set bit.

    Returns 0 when `bits` has fewer than `n` set bits. Note that 0 is also
    the valid answer when bit 0 is the one selected, so callers that need to
    tell the two cases apart must check `bits` themselves.
    """
    cdef uint32_t i
    for i in range(32):
        # Shift an unsigned one: left-shifting a signed int literal by 31 is
        # undefined behaviour in C, because the result is not representable.
        if bits & (<uint32_t>1 << i):
            n -= 1
            if n < 1:
                return i
    return 0
cdef const TokenC* _new_get_left(const State* s, const TokenC* target, int idx) nogil:
    """Find the idx-th token to the left of `target` whose head is `target`.

    The sentence is scanned left to right starting at `s.sent`, so idx == 1
    selects the first such token encountered. `head` is read as an offset
    relative to the token itself. Returns NULL when idx < 1 or when fewer
    than idx matching tokens precede `target`.
    """
    if idx < 1:
        return NULL
    cdef const TokenC* cursor = s.sent
    while cursor < target:
        if cursor.head >= 1 and (cursor + cursor.head) < target:
            # This token's head lies to its right but still before the
            # target: jump straight to that head, as nothing in between is
            # treated as a candidate child of the target.
            cursor += cursor.head
            continue
        if cursor + cursor.head == target:
            idx -= 1
            if idx == 0:
                return cursor
        cursor += 1
    return NULL
cdef const TokenC* _new_get_right(const State* s, const TokenC* target, int idx) nogil:
    """Find the idx-th token to the right of `target` whose head is `target`.

    The sentence is scanned right to left starting at its last token
    (`s.sent + s.sent_len - 1`), so idx == 1 selects the first such token
    encountered from the right. `head` is read as an offset relative to the
    token itself. Returns NULL when idx < 1 or when fewer than idx matching
    tokens follow `target`.
    """
    if idx < 1:
        return NULL
    cdef const TokenC* cursor = s.sent + (s.sent_len - 1)
    while cursor > target:
        if cursor.head < 0 and (cursor + cursor.head) > target:
            # This token's head lies to its left but still after the target:
            # jump straight to that head, as nothing in between is treated
            # as a candidate child of the target.
            cursor += cursor.head
            continue
        if cursor + cursor.head == target:
            idx -= 1
            if idx == 0:
                return cursor
        cursor -= 1
    return NULL

View File

@ -55,6 +55,8 @@ cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -
cdef int cost = 0
cost += head_in_stack(st, target, gold.heads)
cost += children_in_stack(st, target, gold.heads)
# If we can Break, we shouldn't push
cost += Break.is_valid(st, -1) and Break.move_cost(st, gold) == 0
return cost
@ -65,15 +67,42 @@ cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1
return cost
cdef int arc_cost(const GoldParseC* gold, int head, int child, int label) except -1:
if gold.heads[child] != head:
cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child) except -1:
if arc_is_gold(gold, head, child):
return 0
elif gold.labels[child] == -1:
return 0
elif gold.labels[child] == label:
return 0
else:
elif (child + st.sent[child].head) == gold.heads[child]:
return 1
elif gold.heads[child] >= st.i:
return 1
else:
return 0
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
    """Report whether attaching `child` to `head` agrees with the gold parse.

    The arc is accepted when the child's gold label is -1, when both words
    satisfy `_is_gold_root`, or when the gold head of `child` is `head`.
    """
    if gold.labels[child] == -1:
        return True
    if _is_gold_root(gold, head) and _is_gold_root(gold, child):
        return True
    return gold.heads[child] == head
cdef bint label_is_gold(const GoldParseC* gold, int head, int child, int label) except -1:
    """Report whether `label` is acceptable for `child` under the gold parse.

    A label is accepted when the child's gold label is -1, when `label`
    itself is -1, or when it equals the child's gold label. `head` is not
    consulted.
    """
    if gold.labels[child] == -1 or label == -1:
        return True
    return gold.labels[child] == label
cdef bint _is_gold_root(const GoldParseC* gold, int word) except -1:
    """True when `word` has gold label -1 or is recorded as its own gold head."""
    if gold.labels[word] == -1:
        return True
    return gold.heads[word] == word
cdef class Shift:
@ -96,11 +125,7 @@ cdef class Shift:
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
cdef int cost = push_cost(s, gold, s.i)
# If we can break, and there's no cost to doing so, we should
if Break.is_valid(s, -1) and Break.cost(s, gold, -1) == 0:
cost += 1
return cost
return push_cost(s, gold, s.i)
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
@ -117,7 +142,7 @@ cdef class Reduce:
@staticmethod
cdef int transition(State* state, int label) except -1:
if NON_MONOTONIC and not has_head(get_s0(state)):
if NON_MONOTONIC and not has_head(get_s0(state)) and state.stack_len >= 2:
add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep)
pop_stack(state)
@ -139,7 +164,6 @@ cdef class Reduce:
return 0
cdef class LeftArc:
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
@ -167,31 +191,14 @@ cdef class LeftArc:
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
if not LeftArc.is_valid(s, -1):
return 9000
cdef int cost = 0
if gold.heads[s.stack[0]] == s.i:
return cost
elif at_eol(s):
# Are we root?
if gold.labels[s.stack[0]] != -1:
# If we're at EOL, prefer to reduce or break over left-arc
if Reduce.is_valid(s, -1) or Break.is_valid(s, -1):
cost += gold.heads[s.stack[0]] != s.stack[0]
return cost
cost += head_in_buffer(s, s.stack[0], gold.heads)
cost += children_in_buffer(s, s.stack[0], gold.heads)
if NON_MONOTONIC and s.stack_len >= 2:
cost += gold.heads[s.stack[0]] == s.stack[-1]
if gold.labels[s.stack[0]] != -1:
cost += gold.heads[s.stack[0]] == s.stack[0]
return cost
elif arc_is_gold(gold, s.i, s.stack[0]):
return 0
else:
return pop_cost(s, gold, s.stack[0]) + arc_cost(s, gold, s.i, s.stack[0])
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
if label == -1 or gold.labels[s.stack[0]] == -1:
return 0
if gold.heads[s.stack[0]] == s.i and label != gold.labels[s.stack[0]]:
return 1
return 0
return arc_is_gold(gold, s.i, s.stack[0]) and not label_is_gold(gold, s.i, s.stack[0], label)
cdef class RightArc:
@ -212,21 +219,14 @@ cdef class RightArc:
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
return push_cost(s, gold, s.i) - (gold.heads[s.i] == s.stack[0])
if arc_is_gold(gold, s.stack[0], s.i):
return 0
else:
return push_cost(s, gold, s.i) + arc_cost(s, gold, s.stack[0], s.i)
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
return arc_cost(gold, s.stack[0], s.i, label)
#cdef int cost = 0
#if gold.heads[s.i] == s.stack[0]:
# cost += label != -1 and label != gold.labels[s.i]
# return cost
# This indicates missing head
#if gold.labels[s.i] != -1:
# cost += head_in_buffer(s, s.i, gold.heads)
#cost += children_in_stack(s, s.i, gold.heads)
#cost += head_in_stack(s, s.i, gold.heads)
#return cost
return arc_is_gold(gold, s.stack[0], s.i) and not label_is_gold(gold, s.stack[0], s.i, label)
cdef class Break:
@ -237,8 +237,10 @@ cdef class Break:
return False
elif at_eol(s):
return False
#elif NON_MONOTONIC:
# return True
elif s.stack_len < 1:
return False
elif NON_MONOTONIC:
return True
else:
# In the Break transition paper, they have this constraint that prevents
# Break if stack is disconnected. But, if we're doing non-monotonic parsing,
@ -262,8 +264,6 @@ cdef class Break:
get_s0(state).dep = label
state.stack -= 1
state.stack_len -= 1
if not at_eol(state):
push_stack(state)
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:

View File

@ -14,7 +14,7 @@ import json
from cymem.cymem cimport Pool, Address
from murmurhash.mrmr cimport hash64
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
from util import Config
@ -34,7 +34,7 @@ from ..strings cimport StringStore
from .arc_eager cimport TransitionSystem, Transition
from .transition_system import OracleError
from ._state cimport State, new_state, copy_state, is_final, push_stack
from ._state cimport State, new_state, copy_state, is_final, push_stack, get_left, get_n0
from ..gold cimport GoldParse
from . import _parse_features
@ -83,14 +83,14 @@ cdef class Parser:
def __call__(self, Tokens tokens):
if tokens.length == 0:
return 0
if self.cfg.get('beam_width', 1) <= 1:
if self.cfg.get('beam_width', 1) < 1:
self._greedy_parse(tokens)
else:
self._beam_parse(tokens)
def train(self, Tokens tokens, GoldParse gold):
self.moves.preprocess_gold(gold)
if self.cfg.beam_width <= 1:
if self.cfg.beam_width < 1:
return self._greedy_train(tokens, gold)
else:
return self._beam_train(tokens, gold)
@ -185,8 +185,7 @@ cdef class Parser:
if follow_gold:
for j in range(self.moves.n_moves):
beam.is_valid[i][j] *= beam.costs[i][j] == 0
beam.advance(_transition_state, <void*>self.moves.c)
state = <State*>beam.at(0)
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
beam.check_done(_check_final_state, NULL)
def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):
@ -222,3 +221,23 @@ cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
cdef int _check_final_state(void* state, void* extra_args) except -1:
return is_final(<State*>state)
cdef hash_t _hash_state(void* _state, void* _) except 0:
    """Hash a parser state from a fixed ten-slot summary.

    `_state` is cast to `const State*`; the second argument is ignored.
    Slots that refer to stack entries the state does not have are set to 0.
    The summary is hashed with `hash64` using seed 0.
    """
    state = <const State*>_state
    cdef atom_t[10] rep

    # Slots that depend on the top of the stack.
    if state.stack_len >= 1:
        rep[0] = state.stack[0]
        rep[4] = state.sent[state.stack[0]].l_kids
        rep[5] = state.sent[state.stack[0]].r_kids
        rep[6] = state.sent[state.stack[0]].dep
    else:
        rep[0] = 0
        rep[4] = 0
        rep[5] = 0
        rep[6] = 0

    # Slots that depend on the second stack entry.
    if state.stack_len >= 2:
        rep[1] = state.stack[-1]
        rep[7] = state.sent[state.stack[-1]].dep
    else:
        rep[1] = 0
        rep[7] = 0

    # Slot that depends on the third stack entry.
    if state.stack_len >= 3:
        rep[2] = state.stack[-2]
    else:
        rep[2] = 0

    rep[3] = state.i

    # Dependency label of the first left child of the token from get_n0, if any.
    n0_left = get_left(state, get_n0(state), 1)
    if n0_left != NULL:
        rep[8] = n0_left.dep
    else:
        rep[8] = 0
    rep[9] = state.sent[state.i].l_kids
    return hash64(rep, sizeof(atom_t) * 10, 0)