mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-27 18:36:36 +03:00
* Don't automatically push words when stack is empty, as it messes up beam parsing. Add hash method to beam state.
This commit is contained in:
parent
d51a86478e
commit
c7e3dfc1dc
|
@ -61,8 +61,8 @@ cdef int pop_stack(State *s) except -1:
|
||||||
assert s.stack_len >= 1
|
assert s.stack_len >= 1
|
||||||
s.stack_len -= 1
|
s.stack_len -= 1
|
||||||
s.stack -= 1
|
s.stack -= 1
|
||||||
if s.stack_len == 0 and not at_eol(s):
|
#if s.stack_len == 0 and not at_eol(s):
|
||||||
push_stack(s)
|
# push_stack(s)
|
||||||
|
|
||||||
|
|
||||||
cdef int push_stack(State *s) except -1:
|
cdef int push_stack(State *s) except -1:
|
||||||
|
@ -114,27 +114,29 @@ cdef bint has_head(const TokenC* t) nogil:
|
||||||
|
|
||||||
|
|
||||||
cdef const TokenC* get_left(const State* s, const TokenC* head, const int idx) nogil:
|
cdef const TokenC* get_left(const State* s, const TokenC* head, const int idx) nogil:
|
||||||
cdef uint32_t kids = head.l_kids
|
return _new_get_left(s, head, idx)
|
||||||
if kids == 0:
|
#cdef uint32_t kids = head.l_kids
|
||||||
return NULL
|
#if kids == 0:
|
||||||
cdef int offset = _nth_significant_bit(kids, idx)
|
# return NULL
|
||||||
cdef const TokenC* child = head - offset
|
#cdef int offset = _nth_significant_bit(kids, idx)
|
||||||
if child >= s.sent:
|
#cdef const TokenC* child = head - offset
|
||||||
return child
|
#if child >= s.sent:
|
||||||
else:
|
# return child
|
||||||
return NULL
|
##else:
|
||||||
|
# return NULL
|
||||||
|
|
||||||
|
|
||||||
cdef const TokenC* get_right(const State* s, const TokenC* head, const int idx) nogil:
|
cdef const TokenC* get_right(const State* s, const TokenC* head, const int idx) nogil:
|
||||||
cdef uint32_t kids = head.r_kids
|
return _new_get_right(s, head, idx)
|
||||||
if kids == 0:
|
#cdef uint32_t kids = head.r_kids
|
||||||
return NULL
|
#if kids == 0:
|
||||||
cdef int offset = _nth_significant_bit(kids, idx)
|
# return NULL
|
||||||
cdef const TokenC* child = head + offset
|
#cdef int offset = _nth_significant_bit(kids, idx)
|
||||||
if child < (s.sent + s.sent_len):
|
#cdef const TokenC* child = head + offset
|
||||||
return child
|
#if child < (s.sent + s.sent_len):
|
||||||
else:
|
# return child
|
||||||
return NULL
|
#else:
|
||||||
|
# return NULL
|
||||||
|
|
||||||
|
|
||||||
cdef int count_left_kids(const TokenC* head) nogil:
|
cdef int count_left_kids(const TokenC* head) nogil:
|
||||||
|
@ -190,7 +192,7 @@ cdef int copy_state(State* dest, const State* src) except -1:
|
||||||
# From https://en.wikipedia.org/wiki/Hamming_weight
|
# From https://en.wikipedia.org/wiki/Hamming_weight
|
||||||
cdef inline uint32_t _popcount(uint32_t x) nogil:
|
cdef inline uint32_t _popcount(uint32_t x) nogil:
|
||||||
"""Find number of non-zero bits."""
|
"""Find number of non-zero bits."""
|
||||||
cdef int count = 0
|
cdef uint32_t count = 0
|
||||||
while x != 0:
|
while x != 0:
|
||||||
x &= x - 1
|
x &= x - 1
|
||||||
count += 1
|
count += 1
|
||||||
|
@ -198,10 +200,51 @@ cdef inline uint32_t _popcount(uint32_t x) nogil:
|
||||||
|
|
||||||
|
|
||||||
cdef inline uint32_t _nth_significant_bit(uint32_t bits, int n) nogil:
|
cdef inline uint32_t _nth_significant_bit(uint32_t bits, int n) nogil:
|
||||||
cdef int i
|
cdef uint32_t i
|
||||||
for i in range(32):
|
for i in range(32):
|
||||||
if bits & (1 << i):
|
if bits & (1 << i):
|
||||||
n -= 1
|
n -= 1
|
||||||
if n < 1:
|
if n < 1:
|
||||||
return i
|
return i
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
cdef const TokenC* _new_get_left(const State* s, const TokenC* target, int idx) nogil:
    """Find the idx-th token, counting from the start of the sentence, whose
    head offset points at `target`.

    `idx` is 1-based. Returns NULL when `idx < 1` or when fewer than `idx`
    such tokens lie before `target`.
    """
    if idx < 1:
        return NULL
    cdef const TokenC* cursor = s.sent
    cdef const TokenC* governor
    while cursor < target:
        # `head` is a relative offset, so this is the token `cursor` attaches to.
        governor = cursor + cursor.head
        if cursor.head >= 1 and governor < target:
            # The head lies further right but still before `target`: jump
            # straight to it. Tokens in between are assumed not to be
            # children of `target`.
            cursor = governor
            continue
        if governor == target:
            idx -= 1
            if idx == 0:
                return cursor
        cursor += 1
    return NULL
|
||||||
|
|
||||||
|
|
||||||
|
cdef const TokenC* _new_get_right(const State* s, const TokenC* target, int idx) nogil:
    """Find the idx-th token, counting back from the end of the sentence,
    whose head offset points at `target`.

    `idx` is 1-based. Returns NULL when `idx < 1` or when fewer than `idx`
    such tokens lie after `target`.
    """
    if idx < 1:
        return NULL
    cdef const TokenC* cursor = s.sent + (s.sent_len - 1)
    cdef const TokenC* governor
    while cursor > target:
        # `head` is a relative offset, so this is the token `cursor` attaches to.
        governor = cursor + cursor.head
        if cursor.head < 0 and governor > target:
            # The head lies further left but still after `target`: jump
            # straight to it. Tokens in between are assumed not to be
            # children of `target`.
            cursor = governor
            continue
        if governor == target:
            idx -= 1
            if idx == 0:
                return cursor
        cursor -= 1
    return NULL
|
||||||
|
|
|
@ -55,6 +55,8 @@ cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -
|
||||||
cdef int cost = 0
|
cdef int cost = 0
|
||||||
cost += head_in_stack(st, target, gold.heads)
|
cost += head_in_stack(st, target, gold.heads)
|
||||||
cost += children_in_stack(st, target, gold.heads)
|
cost += children_in_stack(st, target, gold.heads)
|
||||||
|
# If we can Break, we shouldn't push
|
||||||
|
cost += Break.is_valid(st, -1) and Break.move_cost(st, gold) == 0
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
@ -65,15 +67,42 @@ cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
cdef int arc_cost(const GoldParseC* gold, int head, int child, int label) except -1:
|
cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child) except -1:
|
||||||
if gold.heads[child] != head:
|
if arc_is_gold(gold, head, child):
|
||||||
return 0
|
return 0
|
||||||
elif gold.labels[child] == -1:
|
elif (child + st.sent[child].head) == gold.heads[child]:
|
||||||
return 0
|
|
||||||
elif gold.labels[child] == label:
|
|
||||||
return 0
|
|
||||||
else:
|
|
||||||
return 1
|
return 1
|
||||||
|
elif gold.heads[child] >= st.i:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
    # Decide whether attaching `child` to `head` is accepted by the gold parse.
    # The checks run in this order, and the first that holds accepts the arc:
    #   1. the gold label of `child` is -1;
    #   2. both `head` and `child` satisfy _is_gold_root;
    #   3. the gold head of `child` is `head`.
    if gold.labels[child] == -1:
        return True
    if _is_gold_root(gold, head) and _is_gold_root(gold, child):
        return True
    return gold.heads[child] == head
|
||||||
|
|
||||||
|
|
||||||
|
cdef bint label_is_gold(const GoldParseC* gold, int head, int child, int label) except -1:
    # Accept `label` for `child` when the gold label is -1, when the proposed
    # label is -1, or when the two are equal. `head` is part of the signature
    # but is not consulted here.
    return gold.labels[child] == -1 or label == -1 or gold.labels[child] == label
|
||||||
|
|
||||||
|
|
||||||
|
cdef bint _is_gold_root(const GoldParseC* gold, int word) except -1:
    # A word is treated as a gold root when its gold label is -1 or when its
    # gold head is the word itself.
    if gold.labels[word] == -1:
        return True
    return gold.heads[word] == word
|
||||||
|
|
||||||
|
|
||||||
cdef class Shift:
|
cdef class Shift:
|
||||||
|
@ -96,11 +125,7 @@ cdef class Shift:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||||
cdef int cost = push_cost(s, gold, s.i)
|
return push_cost(s, gold, s.i)
|
||||||
# If we can break, and there's no cost to doing so, we should
|
|
||||||
if Break.is_valid(s, -1) and Break.cost(s, gold, -1) == 0:
|
|
||||||
cost += 1
|
|
||||||
return cost
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
|
@ -117,7 +142,7 @@ cdef class Reduce:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(State* state, int label) except -1:
|
cdef int transition(State* state, int label) except -1:
|
||||||
if NON_MONOTONIC and not has_head(get_s0(state)):
|
if NON_MONOTONIC and not has_head(get_s0(state)) and state.stack_len >= 2:
|
||||||
add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep)
|
add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep)
|
||||||
pop_stack(state)
|
pop_stack(state)
|
||||||
|
|
||||||
|
@ -139,7 +164,6 @@ cdef class Reduce:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cdef class LeftArc:
|
cdef class LeftArc:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef bint is_valid(const State* s, int label) except -1:
|
cdef bint is_valid(const State* s, int label) except -1:
|
||||||
|
@ -167,31 +191,14 @@ cdef class LeftArc:
|
||||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||||
if not LeftArc.is_valid(s, -1):
|
if not LeftArc.is_valid(s, -1):
|
||||||
return 9000
|
return 9000
|
||||||
cdef int cost = 0
|
elif arc_is_gold(gold, s.i, s.stack[0]):
|
||||||
if gold.heads[s.stack[0]] == s.i:
|
return 0
|
||||||
return cost
|
else:
|
||||||
elif at_eol(s):
|
return pop_cost(s, gold, s.stack[0]) + arc_cost(s, gold, s.i, s.stack[0])
|
||||||
# Are we root?
|
|
||||||
if gold.labels[s.stack[0]] != -1:
|
|
||||||
# If we're at EOL, prefer to reduce or break over left-arc
|
|
||||||
if Reduce.is_valid(s, -1) or Break.is_valid(s, -1):
|
|
||||||
cost += gold.heads[s.stack[0]] != s.stack[0]
|
|
||||||
return cost
|
|
||||||
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
|
||||||
cost += children_in_buffer(s, s.stack[0], gold.heads)
|
|
||||||
if NON_MONOTONIC and s.stack_len >= 2:
|
|
||||||
cost += gold.heads[s.stack[0]] == s.stack[-1]
|
|
||||||
if gold.labels[s.stack[0]] != -1:
|
|
||||||
cost += gold.heads[s.stack[0]] == s.stack[0]
|
|
||||||
return cost
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
if label == -1 or gold.labels[s.stack[0]] == -1:
|
return arc_is_gold(gold, s.i, s.stack[0]) and not label_is_gold(gold, s.i, s.stack[0], label)
|
||||||
return 0
|
|
||||||
if gold.heads[s.stack[0]] == s.i and label != gold.labels[s.stack[0]]:
|
|
||||||
return 1
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
cdef class RightArc:
|
cdef class RightArc:
|
||||||
|
@ -212,21 +219,14 @@ cdef class RightArc:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||||
return push_cost(s, gold, s.i) - (gold.heads[s.i] == s.stack[0])
|
if arc_is_gold(gold, s.stack[0], s.i):
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return push_cost(s, gold, s.i) + arc_cost(s, gold, s.stack[0], s.i)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
return arc_cost(gold, s.stack[0], s.i, label)
|
return arc_is_gold(gold, s.stack[0], s.i) and not label_is_gold(gold, s.stack[0], s.i, label)
|
||||||
#cdef int cost = 0
|
|
||||||
#if gold.heads[s.i] == s.stack[0]:
|
|
||||||
# cost += label != -1 and label != gold.labels[s.i]
|
|
||||||
# return cost
|
|
||||||
# This indicates missing head
|
|
||||||
#if gold.labels[s.i] != -1:
|
|
||||||
# cost += head_in_buffer(s, s.i, gold.heads)
|
|
||||||
#cost += children_in_stack(s, s.i, gold.heads)
|
|
||||||
#cost += head_in_stack(s, s.i, gold.heads)
|
|
||||||
#return cost
|
|
||||||
|
|
||||||
|
|
||||||
cdef class Break:
|
cdef class Break:
|
||||||
|
@ -237,8 +237,10 @@ cdef class Break:
|
||||||
return False
|
return False
|
||||||
elif at_eol(s):
|
elif at_eol(s):
|
||||||
return False
|
return False
|
||||||
#elif NON_MONOTONIC:
|
elif s.stack_len < 1:
|
||||||
# return True
|
return False
|
||||||
|
elif NON_MONOTONIC:
|
||||||
|
return True
|
||||||
else:
|
else:
|
||||||
# In the Break transition paper, they have this constraint that prevents
|
# In the Break transition paper, they have this constraint that prevents
|
||||||
# Break if stack is disconnected. But, if we're doing non-monotonic parsing,
|
# Break if stack is disconnected. But, if we're doing non-monotonic parsing,
|
||||||
|
@ -262,8 +264,6 @@ cdef class Break:
|
||||||
get_s0(state).dep = label
|
get_s0(state).dep = label
|
||||||
state.stack -= 1
|
state.stack -= 1
|
||||||
state.stack_len -= 1
|
state.stack_len -= 1
|
||||||
if not at_eol(state):
|
|
||||||
push_stack(state)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
|
|
|
@ -14,7 +14,7 @@ import json
|
||||||
|
|
||||||
from cymem.cymem cimport Pool, Address
|
from cymem.cymem cimport Pool, Address
|
||||||
from murmurhash.mrmr cimport hash64
|
from murmurhash.mrmr cimport hash64
|
||||||
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t
|
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
|
||||||
|
|
||||||
|
|
||||||
from util import Config
|
from util import Config
|
||||||
|
@ -34,7 +34,7 @@ from ..strings cimport StringStore
|
||||||
from .arc_eager cimport TransitionSystem, Transition
|
from .arc_eager cimport TransitionSystem, Transition
|
||||||
from .transition_system import OracleError
|
from .transition_system import OracleError
|
||||||
|
|
||||||
from ._state cimport State, new_state, copy_state, is_final, push_stack
|
from ._state cimport State, new_state, copy_state, is_final, push_stack, get_left, get_n0
|
||||||
from ..gold cimport GoldParse
|
from ..gold cimport GoldParse
|
||||||
|
|
||||||
from . import _parse_features
|
from . import _parse_features
|
||||||
|
@ -83,14 +83,14 @@ cdef class Parser:
|
||||||
def __call__(self, Tokens tokens):
|
def __call__(self, Tokens tokens):
|
||||||
if tokens.length == 0:
|
if tokens.length == 0:
|
||||||
return 0
|
return 0
|
||||||
if self.cfg.get('beam_width', 1) <= 1:
|
if self.cfg.get('beam_width', 1) < 1:
|
||||||
self._greedy_parse(tokens)
|
self._greedy_parse(tokens)
|
||||||
else:
|
else:
|
||||||
self._beam_parse(tokens)
|
self._beam_parse(tokens)
|
||||||
|
|
||||||
def train(self, Tokens tokens, GoldParse gold):
|
def train(self, Tokens tokens, GoldParse gold):
|
||||||
self.moves.preprocess_gold(gold)
|
self.moves.preprocess_gold(gold)
|
||||||
if self.cfg.beam_width <= 1:
|
if self.cfg.beam_width < 1:
|
||||||
return self._greedy_train(tokens, gold)
|
return self._greedy_train(tokens, gold)
|
||||||
else:
|
else:
|
||||||
return self._beam_train(tokens, gold)
|
return self._beam_train(tokens, gold)
|
||||||
|
@ -185,8 +185,7 @@ cdef class Parser:
|
||||||
if follow_gold:
|
if follow_gold:
|
||||||
for j in range(self.moves.n_moves):
|
for j in range(self.moves.n_moves):
|
||||||
beam.is_valid[i][j] *= beam.costs[i][j] == 0
|
beam.is_valid[i][j] *= beam.costs[i][j] == 0
|
||||||
beam.advance(_transition_state, <void*>self.moves.c)
|
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
||||||
state = <State*>beam.at(0)
|
|
||||||
beam.check_done(_check_final_state, NULL)
|
beam.check_done(_check_final_state, NULL)
|
||||||
|
|
||||||
def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):
|
def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):
|
||||||
|
@ -222,3 +221,23 @@ cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
||||||
|
|
||||||
cdef int _check_final_state(void* state, void* extra_args) except -1:
|
cdef int _check_final_state(void* state, void* extra_args) except -1:
|
||||||
return is_final(<State*>state)
|
return is_final(<State*>state)
|
||||||
|
|
||||||
|
|
||||||
|
cdef hash_t _hash_state(void* _state, void* _) except 0:
    """Hash a parse state from a fixed 10-slot summary of it.

    `_state` is cast to `const State*`; the second argument is unused.
    Slots that refer to stack entries which do not exist are filled with 0.

    NOTE(review): the signature declares `except 0`, so a genuine hash value
    of 0 would be reported to the caller as an error return.
    """
    state = <const State*>_state
    cdef atom_t[10] rep

    # Top three stack entries (stack[0] is the top, negative indices go deeper).
    rep[0] = state.stack[0] if state.stack_len >= 1 else 0
    rep[1] = state.stack[-1] if state.stack_len >= 2 else 0
    rep[2] = state.stack[-2] if state.stack_len >= 3 else 0
    # Current buffer position.
    rep[3] = state.i
    # Child fields and dependency label of the stack top, and the label of the
    # entry beneath it.
    rep[4] = state.sent[state.stack[0]].l_kids if state.stack_len >= 1 else 0
    rep[5] = state.sent[state.stack[0]].r_kids if state.stack_len >= 1 else 0
    rep[6] = state.sent[state.stack[0]].dep if state.stack_len >= 1 else 0
    rep[7] = state.sent[state.stack[-1]].dep if state.stack_len >= 2 else 0
    # Look up the first left child of n0 once and reuse it, rather than
    # repeating the search to read its label. The variable's pointer type is
    # left to inference, as with `state` above.
    n0_left = get_left(state, get_n0(state), 1)
    if n0_left != NULL:
        rep[8] = n0_left.dep
    else:
        rep[8] = 0
    rep[9] = state.sent[state.i].l_kids
    return hash64(rep, sizeof(atom_t) * 10, 0)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user