mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
* Adjust arc_eager oracle, so that recovering errors via non-monotonic actions gives negative cost. Need to test this with greedy parser.
This commit is contained in:
parent
0bf448461e
commit
1ee6b468a9
|
@ -1,3 +1,6 @@
|
|||
# cython: profile=True
|
||||
# cython: cdivision=True
|
||||
# cython: infer_types=True
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import ctypes
|
||||
|
@ -24,6 +27,7 @@ from .nonproj import PseudoProjectivity
|
|||
DEF NON_MONOTONIC = True
|
||||
DEF USE_BREAK = True
|
||||
|
||||
|
||||
cdef weight_t MIN_SCORE = -90000
|
||||
|
||||
# Break transition from here
|
||||
|
@ -65,10 +69,12 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no
|
|||
cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
|
||||
cdef weight_t cost = 0
|
||||
cdef int i, B_i
|
||||
# Count number of words in buffer with deendencies to/from the target.
|
||||
for i in range(stcls.buffer_length()):
|
||||
B_i = stcls.B(i)
|
||||
cost += gold.heads[B_i] == target
|
||||
cost += gold.heads[target] == B_i
|
||||
# TODO: Should re-examine this for German --- it assumes projectivity.
|
||||
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
|
||||
break
|
||||
if Break.is_valid(stcls.c, -1) and Break.move_cost(stcls, gold) == 0:
|
||||
|
@ -154,7 +160,18 @@ cdef class Reduce:
|
|||
|
||||
@staticmethod
|
||||
cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
|
||||
return pop_cost(st, gold, st.S(0))
|
||||
cost = pop_cost(st, gold, st.S(0))
|
||||
if not st.has_head(st.S(0)):
|
||||
# Decrement cost for the arcs we save
|
||||
for i in range(1, st.stack_depth()):
|
||||
S_i = st.S(i)
|
||||
if gold.heads[st.S(0)] == S_i:
|
||||
cost -= 1
|
||||
if gold.heads[S_i] == st.S(0):
|
||||
cost -= 1
|
||||
if Break.is_valid(st.c, -1) and Break.move_cost(st, gold) == 0:
|
||||
cost -= 1
|
||||
return cost
|
||||
|
||||
@staticmethod
|
||||
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
|
@ -180,7 +197,8 @@ cdef class LeftArc:
|
|||
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||
cdef weight_t cost = 0
|
||||
if arc_is_gold(gold, s.B(0), s.S(0)):
|
||||
return 0
|
||||
# Have a negative cost if we 'recover' from the wrong dependency
|
||||
return 0 if not s.has_head(s.S(0)) else -1
|
||||
else:
|
||||
# Account for deps we might lose between S0 and stack
|
||||
if not s.has_head(s.S(0)):
|
||||
|
@ -407,7 +425,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
cdef move_cost_func_t[N_MOVES] move_cost_funcs
|
||||
cdef weight_t[N_MOVES] move_costs
|
||||
for i in range(N_MOVES):
|
||||
move_costs[i] = -1
|
||||
move_costs[i] = 9000
|
||||
move_cost_funcs[SHIFT] = Shift.move_cost
|
||||
move_cost_funcs[REDUCE] = Reduce.move_cost
|
||||
move_cost_funcs[LEFT] = LeftArc.move_cost
|
||||
|
@ -429,10 +447,10 @@ cdef class ArcEager(TransitionSystem):
|
|||
is_valid[i] = True
|
||||
move = self.c[i].move
|
||||
label = self.c[i].label
|
||||
if move_costs[move] == -1:
|
||||
if move_costs[move] == 9000:
|
||||
move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
|
||||
costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
|
||||
n_gold += costs[i] == 0
|
||||
n_gold += costs[i] <= 0
|
||||
else:
|
||||
is_valid[i] = False
|
||||
costs[i] = 9000
|
||||
|
|
Loading…
Reference in New Issue
Block a user