mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 09:56:28 +03:00
Record negative costs in non-monotonic arc eager oracle
This commit is contained in:
parent
ecf91a2dbb
commit
d11f1a4ddf
|
@ -1,3 +1,6 @@
|
||||||
|
# cython: profile=True
|
||||||
|
# cython: cdivision=True
|
||||||
|
# cython: infer_types=True
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import ctypes
|
import ctypes
|
||||||
|
@ -155,7 +158,18 @@ cdef class Reduce:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
|
cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
|
||||||
return pop_cost(st, gold, st.S(0))
|
cost = pop_cost(st, gold, st.S(0))
|
||||||
|
if not st.has_head(st.S(0)):
|
||||||
|
# Decrement cost for the arcs e save
|
||||||
|
for i in range(1, st.stack_depth()):
|
||||||
|
S_i = st.S(i)
|
||||||
|
if gold.heads[st.S(0)] == S_i:
|
||||||
|
cost -= 1
|
||||||
|
if gold.heads[S_i] == st.S(0):
|
||||||
|
cost -= 1
|
||||||
|
if Break.is_valid(st.c, -1) and Break.move_cost(st, gold) == 0:
|
||||||
|
cost -= 1
|
||||||
|
return cost
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||||
|
@ -181,7 +195,8 @@ cdef class LeftArc:
|
||||||
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
|
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||||
cdef weight_t cost = 0
|
cdef weight_t cost = 0
|
||||||
if arc_is_gold(gold, s.B(0), s.S(0)):
|
if arc_is_gold(gold, s.B(0), s.S(0)):
|
||||||
return 0
|
# Have a negative cost if we 'recover' from the wrong dependency
|
||||||
|
return 0 if not s.has_head(s.S(0)) else -1
|
||||||
else:
|
else:
|
||||||
# Account for deps we might lose between S0 and stack
|
# Account for deps we might lose between S0 and stack
|
||||||
if not s.has_head(s.S(0)):
|
if not s.has_head(s.S(0)):
|
||||||
|
@ -281,7 +296,7 @@ cdef int _get_root(int word, const GoldParseC* gold) nogil:
|
||||||
cdef class ArcEager(TransitionSystem):
|
cdef class ArcEager(TransitionSystem):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_actions(cls, **kwargs):
|
def get_actions(cls, **kwargs):
|
||||||
actions = kwargs.get('actions',
|
actions = kwargs.get('actions',
|
||||||
{
|
{
|
||||||
SHIFT: {'': True},
|
SHIFT: {'': True},
|
||||||
REDUCE: {'': True},
|
REDUCE: {'': True},
|
||||||
|
@ -294,7 +309,7 @@ cdef class ArcEager(TransitionSystem):
|
||||||
for label in kwargs.get('right_labels', []):
|
for label in kwargs.get('right_labels', []):
|
||||||
if label.upper() != 'ROOT':
|
if label.upper() != 'ROOT':
|
||||||
actions[RIGHT][label] = True
|
actions[RIGHT][label] = True
|
||||||
|
|
||||||
for raw_text, sents in kwargs.get('gold_parses', []):
|
for raw_text, sents in kwargs.get('gold_parses', []):
|
||||||
for (ids, words, tags, heads, labels, iob), ctnts in sents:
|
for (ids, words, tags, heads, labels, iob), ctnts in sents:
|
||||||
for child, head, label in zip(ids, heads, labels):
|
for child, head, label in zip(ids, heads, labels):
|
||||||
|
@ -407,14 +422,14 @@ cdef class ArcEager(TransitionSystem):
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
output[i] = is_valid[self.c[i].move]
|
output[i] = is_valid[self.c[i].move]
|
||||||
|
|
||||||
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
||||||
StateClass stcls, GoldParse gold) except -1:
|
StateClass stcls, GoldParse gold) except -1:
|
||||||
cdef int i, move, label
|
cdef int i, move, label
|
||||||
cdef label_cost_func_t[N_MOVES] label_cost_funcs
|
cdef label_cost_func_t[N_MOVES] label_cost_funcs
|
||||||
cdef move_cost_func_t[N_MOVES] move_cost_funcs
|
cdef move_cost_func_t[N_MOVES] move_cost_funcs
|
||||||
cdef weight_t[N_MOVES] move_costs
|
cdef weight_t[N_MOVES] move_costs
|
||||||
for i in range(N_MOVES):
|
for i in range(N_MOVES):
|
||||||
move_costs[i] = -1
|
move_costs[i] = 9000
|
||||||
move_cost_funcs[SHIFT] = Shift.move_cost
|
move_cost_funcs[SHIFT] = Shift.move_cost
|
||||||
move_cost_funcs[REDUCE] = Reduce.move_cost
|
move_cost_funcs[REDUCE] = Reduce.move_cost
|
||||||
move_cost_funcs[LEFT] = LeftArc.move_cost
|
move_cost_funcs[LEFT] = LeftArc.move_cost
|
||||||
|
@ -436,14 +451,14 @@ cdef class ArcEager(TransitionSystem):
|
||||||
is_valid[i] = True
|
is_valid[i] = True
|
||||||
move = self.c[i].move
|
move = self.c[i].move
|
||||||
label = self.c[i].label
|
label = self.c[i].label
|
||||||
if move_costs[move] == -1:
|
if move_costs[move] == 9000:
|
||||||
move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
|
move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
|
||||||
costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
|
costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
|
||||||
n_gold += costs[i] <= 0
|
n_gold += costs[i] <= 0
|
||||||
else:
|
else:
|
||||||
is_valid[i] = False
|
is_valid[i] = False
|
||||||
costs[i] = 9000
|
costs[i] = 9000
|
||||||
if n_gold == 0:
|
if n_gold < 1:
|
||||||
# Check projectivity --- leading cause
|
# Check projectivity --- leading cause
|
||||||
if is_nonproj_tree(gold.heads):
|
if is_nonproj_tree(gold.heads):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -463,7 +478,7 @@ cdef class ArcEager(TransitionSystem):
|
||||||
"Could not find a gold-standard action to supervise the dependency "
|
"Could not find a gold-standard action to supervise the dependency "
|
||||||
"parser.\n"
|
"parser.\n"
|
||||||
"The GoldParse was projective.\n"
|
"The GoldParse was projective.\n"
|
||||||
"The transition system has %d actions.\n"
|
"The transition system has %d actions.\n"
|
||||||
"State at failure:\n"
|
"State at failure:\n"
|
||||||
"%s" % (self.n_moves, stcls.print_state(gold.words)))
|
"%s" % (self.n_moves, stcls.print_state(gold.words)))
|
||||||
assert n_gold >= 1
|
assert n_gold >= 1
|
||||||
|
|
Loading…
Reference in New Issue
Block a user