mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-14 03:26:24 +03:00
* Don't automatically push words when stack is empty, as it messes up beam parsing. Add hash method to beam state.
This commit is contained in:
parent
d51a86478e
commit
c7e3dfc1dc
|
@ -61,8 +61,8 @@ cdef int pop_stack(State *s) except -1:
|
|||
assert s.stack_len >= 1
|
||||
s.stack_len -= 1
|
||||
s.stack -= 1
|
||||
if s.stack_len == 0 and not at_eol(s):
|
||||
push_stack(s)
|
||||
#if s.stack_len == 0 and not at_eol(s):
|
||||
# push_stack(s)
|
||||
|
||||
|
||||
cdef int push_stack(State *s) except -1:
|
||||
|
@ -114,27 +114,29 @@ cdef bint has_head(const TokenC* t) nogil:
|
|||
|
||||
|
||||
cdef const TokenC* get_left(const State* s, const TokenC* head, const int idx) nogil:
|
||||
cdef uint32_t kids = head.l_kids
|
||||
if kids == 0:
|
||||
return NULL
|
||||
cdef int offset = _nth_significant_bit(kids, idx)
|
||||
cdef const TokenC* child = head - offset
|
||||
if child >= s.sent:
|
||||
return child
|
||||
else:
|
||||
return NULL
|
||||
return _new_get_left(s, head, idx)
|
||||
#cdef uint32_t kids = head.l_kids
|
||||
#if kids == 0:
|
||||
# return NULL
|
||||
#cdef int offset = _nth_significant_bit(kids, idx)
|
||||
#cdef const TokenC* child = head - offset
|
||||
#if child >= s.sent:
|
||||
# return child
|
||||
##else:
|
||||
# return NULL
|
||||
|
||||
|
||||
cdef const TokenC* get_right(const State* s, const TokenC* head, const int idx) nogil:
|
||||
cdef uint32_t kids = head.r_kids
|
||||
if kids == 0:
|
||||
return NULL
|
||||
cdef int offset = _nth_significant_bit(kids, idx)
|
||||
cdef const TokenC* child = head + offset
|
||||
if child < (s.sent + s.sent_len):
|
||||
return child
|
||||
else:
|
||||
return NULL
|
||||
return _new_get_right(s, head, idx)
|
||||
#cdef uint32_t kids = head.r_kids
|
||||
#if kids == 0:
|
||||
# return NULL
|
||||
#cdef int offset = _nth_significant_bit(kids, idx)
|
||||
#cdef const TokenC* child = head + offset
|
||||
#if child < (s.sent + s.sent_len):
|
||||
# return child
|
||||
#else:
|
||||
# return NULL
|
||||
|
||||
|
||||
cdef int count_left_kids(const TokenC* head) nogil:
|
||||
|
@ -190,7 +192,7 @@ cdef int copy_state(State* dest, const State* src) except -1:
|
|||
# From https://en.wikipedia.org/wiki/Hamming_weight
|
||||
cdef inline uint32_t _popcount(uint32_t x) nogil:
|
||||
"""Find number of non-zero bits."""
|
||||
cdef int count = 0
|
||||
cdef uint32_t count = 0
|
||||
while x != 0:
|
||||
x &= x - 1
|
||||
count += 1
|
||||
|
@ -198,10 +200,51 @@ cdef inline uint32_t _popcount(uint32_t x) nogil:
|
|||
|
||||
|
||||
cdef inline uint32_t _nth_significant_bit(uint32_t bits, int n) nogil:
|
||||
cdef int i
|
||||
cdef uint32_t i
|
||||
for i in range(32):
|
||||
if bits & (1 << i):
|
||||
n -= 1
|
||||
if n < 1:
|
||||
return i
|
||||
return 0
|
||||
|
||||
|
||||
cdef const TokenC* _new_get_left(const State* s, const TokenC* target, int idx) nogil:
|
||||
if idx < 1:
|
||||
return NULL
|
||||
cdef const TokenC* ptr = s.sent
|
||||
while ptr < target:
|
||||
# If this head is still to the right of us, we can skip to it
|
||||
# No token that's between this token and this head could be our
|
||||
# child.
|
||||
if (ptr.head >= 1) and (ptr + ptr.head) < target:
|
||||
ptr += ptr.head
|
||||
|
||||
elif ptr + ptr.head == target:
|
||||
idx -= 1
|
||||
if idx == 0:
|
||||
return ptr
|
||||
ptr += 1
|
||||
else:
|
||||
ptr += 1
|
||||
return NULL
|
||||
|
||||
|
||||
cdef const TokenC* _new_get_right(const State* s, const TokenC* target, int idx) nogil:
|
||||
if idx < 1:
|
||||
return NULL
|
||||
cdef const TokenC* ptr = s.sent + (s.sent_len - 1)
|
||||
while ptr > target:
|
||||
# If this head is still to the right of us, we can skip to it
|
||||
# No token that's between this token and this head could be our
|
||||
# child.
|
||||
if (ptr.head < 0) and ((ptr + ptr.head) > target):
|
||||
ptr += ptr.head
|
||||
elif ptr + ptr.head == target:
|
||||
idx -= 1
|
||||
if idx == 0:
|
||||
return ptr
|
||||
ptr -= 1
|
||||
else:
|
||||
ptr -= 1
|
||||
return NULL
|
||||
|
|
|
@ -55,6 +55,8 @@ cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -
|
|||
cdef int cost = 0
|
||||
cost += head_in_stack(st, target, gold.heads)
|
||||
cost += children_in_stack(st, target, gold.heads)
|
||||
# If we can Break, we shouldn't push
|
||||
cost += Break.is_valid(st, -1) and Break.move_cost(st, gold) == 0
|
||||
return cost
|
||||
|
||||
|
||||
|
@ -65,15 +67,42 @@ cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1
|
|||
return cost
|
||||
|
||||
|
||||
cdef int arc_cost(const GoldParseC* gold, int head, int child, int label) except -1:
|
||||
if gold.heads[child] != head:
|
||||
cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child) except -1:
|
||||
if arc_is_gold(gold, head, child):
|
||||
return 0
|
||||
elif gold.labels[child] == -1:
|
||||
return 0
|
||||
elif gold.labels[child] == label:
|
||||
return 0
|
||||
else:
|
||||
elif (child + st.sent[child].head) == gold.heads[child]:
|
||||
return 1
|
||||
elif gold.heads[child] >= st.i:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
|
||||
if gold.labels[child] == -1:
|
||||
return True
|
||||
elif _is_gold_root(gold, head) and _is_gold_root(gold, child):
|
||||
return True
|
||||
elif gold.heads[child] == head:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
cdef bint label_is_gold(const GoldParseC* gold, int head, int child, int label) except -1:
|
||||
if gold.labels[child] == -1:
|
||||
return True
|
||||
elif label == -1:
|
||||
return True
|
||||
elif gold.labels[child] == label:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
cdef bint _is_gold_root(const GoldParseC* gold, int word) except -1:
|
||||
return gold.labels[word] == -1 or gold.heads[word] == word
|
||||
|
||||
|
||||
cdef class Shift:
|
||||
|
@ -96,11 +125,7 @@ cdef class Shift:
|
|||
|
||||
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
cdef int cost = push_cost(s, gold, s.i)
|
||||
# If we can break, and there's no cost to doing so, we should
|
||||
if Break.is_valid(s, -1) and Break.cost(s, gold, -1) == 0:
|
||||
cost += 1
|
||||
return cost
|
||||
return push_cost(s, gold, s.i)
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
|
@ -117,7 +142,7 @@ cdef class Reduce:
|
|||
|
||||
@staticmethod
|
||||
cdef int transition(State* state, int label) except -1:
|
||||
if NON_MONOTONIC and not has_head(get_s0(state)):
|
||||
if NON_MONOTONIC and not has_head(get_s0(state)) and state.stack_len >= 2:
|
||||
add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep)
|
||||
pop_stack(state)
|
||||
|
||||
|
@ -139,7 +164,6 @@ cdef class Reduce:
|
|||
return 0
|
||||
|
||||
|
||||
|
||||
cdef class LeftArc:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const State* s, int label) except -1:
|
||||
|
@ -167,31 +191,14 @@ cdef class LeftArc:
|
|||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
if not LeftArc.is_valid(s, -1):
|
||||
return 9000
|
||||
cdef int cost = 0
|
||||
if gold.heads[s.stack[0]] == s.i:
|
||||
return cost
|
||||
elif at_eol(s):
|
||||
# Are we root?
|
||||
if gold.labels[s.stack[0]] != -1:
|
||||
# If we're at EOL, prefer to reduce or break over left-arc
|
||||
if Reduce.is_valid(s, -1) or Break.is_valid(s, -1):
|
||||
cost += gold.heads[s.stack[0]] != s.stack[0]
|
||||
return cost
|
||||
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
||||
cost += children_in_buffer(s, s.stack[0], gold.heads)
|
||||
if NON_MONOTONIC and s.stack_len >= 2:
|
||||
cost += gold.heads[s.stack[0]] == s.stack[-1]
|
||||
if gold.labels[s.stack[0]] != -1:
|
||||
cost += gold.heads[s.stack[0]] == s.stack[0]
|
||||
return cost
|
||||
elif arc_is_gold(gold, s.i, s.stack[0]):
|
||||
return 0
|
||||
else:
|
||||
return pop_cost(s, gold, s.stack[0]) + arc_cost(s, gold, s.i, s.stack[0])
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if label == -1 or gold.labels[s.stack[0]] == -1:
|
||||
return 0
|
||||
if gold.heads[s.stack[0]] == s.i and label != gold.labels[s.stack[0]]:
|
||||
return 1
|
||||
return 0
|
||||
return arc_is_gold(gold, s.i, s.stack[0]) and not label_is_gold(gold, s.i, s.stack[0], label)
|
||||
|
||||
|
||||
cdef class RightArc:
|
||||
|
@ -212,21 +219,14 @@ cdef class RightArc:
|
|||
|
||||
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
return push_cost(s, gold, s.i) - (gold.heads[s.i] == s.stack[0])
|
||||
if arc_is_gold(gold, s.stack[0], s.i):
|
||||
return 0
|
||||
else:
|
||||
return push_cost(s, gold, s.i) + arc_cost(s, gold, s.stack[0], s.i)
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
return arc_cost(gold, s.stack[0], s.i, label)
|
||||
#cdef int cost = 0
|
||||
#if gold.heads[s.i] == s.stack[0]:
|
||||
# cost += label != -1 and label != gold.labels[s.i]
|
||||
# return cost
|
||||
# This indicates missing head
|
||||
#if gold.labels[s.i] != -1:
|
||||
# cost += head_in_buffer(s, s.i, gold.heads)
|
||||
#cost += children_in_stack(s, s.i, gold.heads)
|
||||
#cost += head_in_stack(s, s.i, gold.heads)
|
||||
#return cost
|
||||
return arc_is_gold(gold, s.stack[0], s.i) and not label_is_gold(gold, s.stack[0], s.i, label)
|
||||
|
||||
|
||||
cdef class Break:
|
||||
|
@ -237,8 +237,10 @@ cdef class Break:
|
|||
return False
|
||||
elif at_eol(s):
|
||||
return False
|
||||
#elif NON_MONOTONIC:
|
||||
# return True
|
||||
elif s.stack_len < 1:
|
||||
return False
|
||||
elif NON_MONOTONIC:
|
||||
return True
|
||||
else:
|
||||
# In the Break transition paper, they have this constraint that prevents
|
||||
# Break if stack is disconnected. But, if we're doing non-monotonic parsing,
|
||||
|
@ -262,8 +264,6 @@ cdef class Break:
|
|||
get_s0(state).dep = label
|
||||
state.stack -= 1
|
||||
state.stack_len -= 1
|
||||
if not at_eol(state):
|
||||
push_stack(state)
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
|
|
|
@ -14,7 +14,7 @@ import json
|
|||
|
||||
from cymem.cymem cimport Pool, Address
|
||||
from murmurhash.mrmr cimport hash64
|
||||
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t
|
||||
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
|
||||
|
||||
|
||||
from util import Config
|
||||
|
@ -34,7 +34,7 @@ from ..strings cimport StringStore
|
|||
from .arc_eager cimport TransitionSystem, Transition
|
||||
from .transition_system import OracleError
|
||||
|
||||
from ._state cimport State, new_state, copy_state, is_final, push_stack
|
||||
from ._state cimport State, new_state, copy_state, is_final, push_stack, get_left, get_n0
|
||||
from ..gold cimport GoldParse
|
||||
|
||||
from . import _parse_features
|
||||
|
@ -83,14 +83,14 @@ cdef class Parser:
|
|||
def __call__(self, Tokens tokens):
|
||||
if tokens.length == 0:
|
||||
return 0
|
||||
if self.cfg.get('beam_width', 1) <= 1:
|
||||
if self.cfg.get('beam_width', 1) < 1:
|
||||
self._greedy_parse(tokens)
|
||||
else:
|
||||
self._beam_parse(tokens)
|
||||
|
||||
def train(self, Tokens tokens, GoldParse gold):
|
||||
self.moves.preprocess_gold(gold)
|
||||
if self.cfg.beam_width <= 1:
|
||||
if self.cfg.beam_width < 1:
|
||||
return self._greedy_train(tokens, gold)
|
||||
else:
|
||||
return self._beam_train(tokens, gold)
|
||||
|
@ -185,8 +185,7 @@ cdef class Parser:
|
|||
if follow_gold:
|
||||
for j in range(self.moves.n_moves):
|
||||
beam.is_valid[i][j] *= beam.costs[i][j] == 0
|
||||
beam.advance(_transition_state, <void*>self.moves.c)
|
||||
state = <State*>beam.at(0)
|
||||
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
|
||||
def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):
|
||||
|
@ -222,3 +221,23 @@ cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
|||
|
||||
cdef int _check_final_state(void* state, void* extra_args) except -1:
|
||||
return is_final(<State*>state)
|
||||
|
||||
|
||||
cdef hash_t _hash_state(void* _state, void* _) except 0:
|
||||
state = <const State*>_state
|
||||
cdef atom_t[10] rep
|
||||
|
||||
rep[0] = state.stack[0] if state.stack_len >= 1 else 0
|
||||
rep[1] = state.stack[-1] if state.stack_len >= 2 else 0
|
||||
rep[2] = state.stack[-2] if state.stack_len >= 3 else 0
|
||||
rep[3] = state.i
|
||||
rep[4] = state.sent[state.stack[0]].l_kids if state.stack_len >= 1 else 0
|
||||
rep[5] = state.sent[state.stack[0]].r_kids if state.stack_len >= 1 else 0
|
||||
rep[6] = state.sent[state.stack[0]].dep if state.stack_len >= 1 else 0
|
||||
rep[7] = state.sent[state.stack[-1]].dep if state.stack_len >= 2 else 0
|
||||
if get_left(state, get_n0(state), 1) != NULL:
|
||||
rep[8] = get_left(state, get_n0(state), 1).dep
|
||||
else:
|
||||
rep[8] = 0
|
||||
rep[9] = state.sent[state.i].l_kids
|
||||
return hash64(rep, sizeof(atom_t) * 10, 0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user