mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 09:00:36 +03:00
WIP on updating transition-system
This commit is contained in:
parent
7d65615625
commit
60d4e5a9e0
|
@ -16,7 +16,6 @@ from thinc.api import Linear, Model, CupyOps, NumpyOps, use_ops, noop
|
||||||
|
|
||||||
from ..typedefs cimport weight_t, class_t, hash_t
|
from ..typedefs cimport weight_t, class_t, hash_t
|
||||||
from ..tokens.doc cimport Doc
|
from ..tokens.doc cimport Doc
|
||||||
from ..gold cimport GoldParse
|
|
||||||
from .stateclass cimport StateClass
|
from .stateclass cimport StateClass
|
||||||
from .transition_system cimport Transition
|
from .transition_system cimport Transition
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,6 @@ import json
|
||||||
|
|
||||||
from ..typedefs cimport hash_t, attr_t
|
from ..typedefs cimport hash_t, attr_t
|
||||||
from ..strings cimport hash_string
|
from ..strings cimport hash_string
|
||||||
from ..gold cimport GoldParse, GoldParseC
|
|
||||||
from ..structs cimport TokenC
|
from ..structs cimport TokenC
|
||||||
from ..tokens.doc cimport Doc, set_children_from_heads
|
from ..tokens.doc cimport Doc, set_children_from_heads
|
||||||
from .stateclass cimport StateClass
|
from .stateclass cimport StateClass
|
||||||
|
@ -49,40 +48,75 @@ MOVE_NAMES[RIGHT] = 'R'
|
||||||
MOVE_NAMES[BREAK] = 'B'
|
MOVE_NAMES[BREAK] = 'B'
|
||||||
|
|
||||||
|
|
||||||
|
cdef enum:
|
||||||
|
HEAD_ON_STACK = 0
|
||||||
|
HEAD_IN_BUFFER
|
||||||
|
IS_SENT_START
|
||||||
|
HEAD_UNKNOWN
|
||||||
|
|
||||||
|
|
||||||
|
cdef struct GoldParseStateC:
|
||||||
|
char* state_bits
|
||||||
|
attr_t* labels
|
||||||
|
int32_t* heads
|
||||||
|
int32_t* n_kids_in_buffer
|
||||||
|
int32_t* n_kids_on_stack
|
||||||
|
int32_t length
|
||||||
|
int32_t stride
|
||||||
|
|
||||||
|
|
||||||
|
cdef int check_state_flag(char state_bits, char flag) nogil:
|
||||||
|
cdef char one = 1
|
||||||
|
return state_bits & (one << flag)
|
||||||
|
|
||||||
|
|
||||||
|
cdef int set_state_flag(char state_bits, char flag, int value) nogil:
|
||||||
|
cdef char one = 1
|
||||||
|
if value:
|
||||||
|
return state_bits | (one << flag)
|
||||||
|
else:
|
||||||
|
return state_bits & ~(one << flag)
|
||||||
|
|
||||||
|
|
||||||
|
cdef int is_head_on_stack(GoldParseStateC gold, int i) nogil:
|
||||||
|
return check_state_gold(gold.state_bits[i], HEAD_ON_STACK)
|
||||||
|
|
||||||
|
|
||||||
|
cdef int is_head_in_buffer(GoldParseStateC gold, int i) nogil:
|
||||||
|
return check_state_gold(gold.state_bits[i], HEAD_IN_BUFFER)
|
||||||
|
|
||||||
|
|
||||||
|
cdef int is_sent_start(GoldParseStateC gold, int i) nogil:
|
||||||
|
return check_state_gold(gold.state_bits[i], IS_SENT_START)
|
||||||
|
|
||||||
|
|
||||||
|
cdef int is_head_unknown(GoldParseStateC gold, int i) nogil:
|
||||||
|
return check_state_gold(gold.state_bits[i], HEAD_UNKNOWN)
|
||||||
|
|
||||||
|
|
||||||
# Helper functions for the arc-eager oracle
|
# Helper functions for the arc-eager oracle
|
||||||
|
|
||||||
cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
|
cdef weight_t push_cost(StateClass stcls, const GoldParseStateC* gold, int target) nogil:
|
||||||
cdef weight_t cost = 0
|
cdef weight_t cost = 0
|
||||||
cdef int i, S_i
|
if is_head_in_stack(gold[0], target):
|
||||||
for i in range(stcls.stack_depth()):
|
cost += 1
|
||||||
S_i = stcls.S(i)
|
cost += gold.n_kids_in_buffer[target]
|
||||||
if gold.heads[target] == S_i:
|
|
||||||
cost += 1
|
|
||||||
if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
|
|
||||||
cost += 1
|
|
||||||
if BINARY_COSTS and cost >= 1:
|
|
||||||
return cost
|
|
||||||
cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0
|
|
||||||
return cost
|
|
||||||
|
|
||||||
|
|
||||||
cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
|
|
||||||
cdef weight_t cost = 0
|
|
||||||
cdef int i, B_i
|
|
||||||
for i in range(stcls.buffer_length()):
|
|
||||||
B_i = stcls.B(i)
|
|
||||||
cost += gold.heads[B_i] == target
|
|
||||||
cost += gold.heads[target] == B_i
|
|
||||||
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
|
|
||||||
break
|
|
||||||
if BINARY_COSTS and cost >= 1:
|
|
||||||
return cost
|
|
||||||
if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
|
if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
|
||||||
cost += 1
|
cost += 1
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
cdef weight_t arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil:
|
cdef weight_t pop_cost(StateClass stcls, const GoldParseStateC* gold, int target) nogil:
|
||||||
|
cdef weight_t cost = 0
|
||||||
|
if is_head_in_buffer(gold[0], target):
|
||||||
|
cost += 1
|
||||||
|
cost += gold[0].n_kids_in_buffer[target]
|
||||||
|
if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
|
||||||
|
cost += 1
|
||||||
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
cdef weight_t arc_cost(StateClass stcls, const GoldParseStateC* gold, int head, int child) nogil:
|
||||||
if arc_is_gold(gold, head, child):
|
if arc_is_gold(gold, head, child):
|
||||||
return 0
|
return 0
|
||||||
elif stcls.H(child) == gold.heads[child]:
|
elif stcls.H(child) == gold.heads[child]:
|
||||||
|
@ -94,8 +128,8 @@ cdef weight_t arc_cost(StateClass stcls, const GoldParseC* gold, int head, int c
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil:
|
cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil:
|
||||||
if not gold.has_dep[child]:
|
if is_head_unknown(gold[0], child):
|
||||||
return True
|
return True
|
||||||
elif gold.heads[child] == head:
|
elif gold.heads[child] == head:
|
||||||
return True
|
return True
|
||||||
|
@ -103,8 +137,8 @@ cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
cdef bint label_is_gold(const GoldParseC* gold, int head, int child, attr_t label) nogil:
|
cdef bint label_is_gold(const GoldParseStateC* gold, int head, int child, attr_t label) nogil:
|
||||||
if not gold.has_dep[child]:
|
if is_head_unknown(gold[0], child):
|
||||||
return True
|
return True
|
||||||
elif label == 0:
|
elif label == 0:
|
||||||
return True
|
return True
|
||||||
|
@ -114,8 +148,9 @@ cdef bint label_is_gold(const GoldParseC* gold, int head, int child, attr_t labe
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
|
cdef bint _is_gold_root(const GoldParseStateC* gold, int word) nogil:
|
||||||
return gold.heads[word] == word or not gold.has_dep[word]
|
return gold.heads[word] == word or is_head_unknown(gold[0], word)
|
||||||
|
|
||||||
|
|
||||||
cdef class Shift:
|
cdef class Shift:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -129,15 +164,16 @@ cdef class Shift:
|
||||||
st.fast_forward()
|
st.fast_forward()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
|
cdef weight_t cost(StateClass st, const void* _gold, attr_t label) nogil:
|
||||||
|
gold = <const GoldParseStateC*>_gold
|
||||||
return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label)
|
return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
|
cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil:
|
||||||
return push_cost(s, gold, s.B(0))
|
return push_cost(s, gold, s.B(0))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
@ -155,26 +191,27 @@ cdef class Reduce:
|
||||||
st.fast_forward()
|
st.fast_forward()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||||
|
gold = <const GoldParseStateC*>_gold
|
||||||
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
|
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
|
cdef inline weight_t move_cost(StateClass st, const GoldParseStateC* gold) nogil:
|
||||||
cost = pop_cost(st, gold, st.S(0))
|
s0 = st.S(0)
|
||||||
if not st.has_head(st.S(0)):
|
cost = pop_cost(st, gold, s0)
|
||||||
# Decrement cost for the arcs e save
|
return_to_buffer = not st.has_head(s0)
|
||||||
for i in range(1, st.stack_depth()):
|
if return_to_buffer:
|
||||||
S_i = st.S(i)
|
# Decrement cost for the arcs we save, as we'll be putting this
|
||||||
if gold.heads[st.S(0)] == S_i:
|
# back to the buffer
|
||||||
cost -= 1
|
if is_head_in_stack(gold[0], s0):
|
||||||
if gold.heads[S_i] == st.S(0):
|
cost -= 1
|
||||||
cost -= 1
|
cost -= gold.n_kids_in_stack[s0]
|
||||||
if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0:
|
if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0:
|
||||||
cost -= 1
|
cost -= 1
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
@ -193,49 +230,12 @@ cdef class LeftArc:
|
||||||
st.fast_forward()
|
st.fast_forward()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef inline weight_t cost(StateClass s, const void* gold, attr_t label) nogil:
|
||||||
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
|
gold = <const GoldParseStateC*>_gold
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
|
|
||||||
cdef weight_t cost = 0
|
|
||||||
if arc_is_gold(gold, s.B(0), s.S(0)):
|
|
||||||
# Have a negative cost if we 'recover' from the wrong dependency
|
|
||||||
return 0 if not s.has_head(s.S(0)) else -1
|
|
||||||
else:
|
|
||||||
# Account for deps we might lose between S0 and stack
|
|
||||||
if not s.has_head(s.S(0)):
|
|
||||||
for i in range(1, s.stack_depth()):
|
|
||||||
cost += gold.heads[s.S(i)] == s.S(0)
|
|
||||||
cost += gold.heads[s.S(0)] == s.S(i)
|
|
||||||
return cost + pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
|
||||||
return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
|
|
||||||
|
|
||||||
|
|
||||||
cdef class RightArc:
|
|
||||||
@staticmethod
|
|
||||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
|
||||||
# If there's (perhaps partial) parse pre-set, don't allow cycle.
|
|
||||||
if label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
|
|
||||||
return 0
|
|
||||||
sent_start = st._sent[st.B_(0).l_edge].sent_start
|
|
||||||
return sent_start != 1 and st.H(st.S(0)) != st.B(0)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
cdef int transition(StateC* st, attr_t label) nogil:
|
|
||||||
st.add_arc(st.S(0), st.B(0), label)
|
|
||||||
st.push()
|
|
||||||
st.fast_forward()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
cdef inline weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
|
||||||
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
|
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
|
cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil:
|
||||||
if arc_is_gold(gold, s.S(0), s.B(0)):
|
if arc_is_gold(gold, s.S(0), s.B(0)):
|
||||||
return 0
|
return 0
|
||||||
elif s.c.shifted[s.B(0)]:
|
elif s.c.shifted[s.B(0)]:
|
||||||
|
@ -244,7 +244,7 @@ cdef class RightArc:
|
||||||
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
|
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil:
|
||||||
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
|
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
|
||||||
|
|
||||||
|
|
||||||
|
@ -271,21 +271,19 @@ cdef class Break:
|
||||||
st.fast_forward()
|
st.fast_forward()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||||
|
gold = <const GoldParseStateC*>_gold
|
||||||
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
|
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
|
cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil:
|
||||||
cdef weight_t cost = 0
|
cdef weight_t cost = 0
|
||||||
cdef int i, j, S_i, B_i
|
cdef int i, j, S_i, B_i
|
||||||
for i in range(s.stack_depth()):
|
for i in range(s.stack_depth()):
|
||||||
S_i = s.S(i)
|
S_i = s.S(i)
|
||||||
for j in range(s.buffer_length()):
|
cost += gold.n_kids_in_buffer[S_i]
|
||||||
B_i = s.B(j)
|
if is_head_in_buffer(gold[0], S_i):
|
||||||
cost += gold.heads[S_i] == B_i
|
cost += 1
|
||||||
cost += gold.heads[B_i] == S_i
|
|
||||||
if cost != 0:
|
|
||||||
return cost
|
|
||||||
# Check for sentence boundary --- if it's here, we can't have any deps
|
# Check for sentence boundary --- if it's here, we can't have any deps
|
||||||
# between stack and buffer, so rest of action is irrelevant.
|
# between stack and buffer, so rest of action is irrelevant.
|
||||||
s0_root = _get_root(s.S(0), gold)
|
s0_root = _get_root(s.S(0), gold)
|
||||||
|
@ -296,14 +294,16 @@ cdef class Break:
|
||||||
return cost + 1
|
return cost + 1
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
cdef int _get_root(int word, const GoldParseC* gold) nogil:
|
cdef int _get_root(int word, const GoldParseStateC* gold) nogil:
|
||||||
while gold.heads[word] != word and gold.has_dep[word] and word >= 0:
|
if is_head_unset(gold[0], word):
|
||||||
word = gold.heads[word]
|
|
||||||
if not gold.has_dep[word]:
|
|
||||||
return -1
|
return -1
|
||||||
|
while gold.heads[word] != word and word >= 0:
|
||||||
|
word = gold.heads[word]
|
||||||
|
if is_head_unset(gold[0], word):
|
||||||
|
return -1
|
||||||
else:
|
else:
|
||||||
return word
|
return word
|
||||||
|
|
||||||
|
@ -378,86 +378,22 @@ cdef class ArcEager(TransitionSystem):
|
||||||
def action_types(self):
|
def action_types(self):
|
||||||
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
|
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
|
||||||
|
|
||||||
def get_cost(self, StateClass state, GoldParse gold, action):
|
def get_cost(self, StateClass state, NewExample gold, action):
|
||||||
cdef Transition t = self.lookup_transition(action)
|
raise NotImplementedError
|
||||||
if not t.is_valid(state.c, t.label):
|
|
||||||
return 9000
|
|
||||||
else:
|
|
||||||
return t.get_cost(state, &gold.c, t.label)
|
|
||||||
|
|
||||||
def transition(self, StateClass state, action):
|
def transition(self, StateClass state, action):
|
||||||
cdef Transition t = self.lookup_transition(action)
|
cdef Transition t = self.lookup_transition(action)
|
||||||
t.do(state.c, t.label)
|
t.do(state.c, t.label)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def is_gold_parse(self, StateClass state, GoldParse gold):
|
def is_gold_parse(self, StateClass state, gold):
|
||||||
predicted = set()
|
raise NotImplementedError
|
||||||
truth = set()
|
|
||||||
for i in range(gold.length):
|
|
||||||
if gold.cand_to_gold[i] is None:
|
|
||||||
continue
|
|
||||||
if state.safe_get(i).dep:
|
|
||||||
predicted.add((i, state.H(i),
|
|
||||||
self.strings[state.safe_get(i).dep]))
|
|
||||||
else:
|
|
||||||
predicted.add((i, state.H(i), 'ROOT'))
|
|
||||||
id_ = gold.orig.ids[gold.cand_to_gold[i]]
|
|
||||||
head = gold.orig.heads[gold.cand_to_gold[i]]
|
|
||||||
dep = gold.orig.deps[gold.cand_to_gold[i]]
|
|
||||||
truth.add((id_, head, dep))
|
|
||||||
return truth == predicted
|
|
||||||
|
|
||||||
def has_gold(self, GoldParse gold, start=0, end=None):
|
def has_gold(self, gold, start=0, end=None):
|
||||||
end = end or len(gold.heads)
|
raise NotImplementedError
|
||||||
if all([tag is None for tag in gold.heads[start:end]]):
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def preprocess_gold(self, GoldParse gold):
|
def preprocess_gold(self, gold):
|
||||||
if not self.has_gold(gold):
|
raise NotImplementedError
|
||||||
return None
|
|
||||||
# Figure out whether we're using subtok
|
|
||||||
use_subtok = False
|
|
||||||
for action, labels in self.labels.items():
|
|
||||||
if SUBTOK_LABEL in labels:
|
|
||||||
use_subtok = True
|
|
||||||
break
|
|
||||||
for i, (head, dep) in enumerate(zip(gold.heads, gold.labels)):
|
|
||||||
# Missing values
|
|
||||||
if head is None or dep is None:
|
|
||||||
gold.c.heads[i] = i
|
|
||||||
gold.c.has_dep[i] = False
|
|
||||||
elif dep == SUBTOK_LABEL and not use_subtok:
|
|
||||||
# If we're not doing the joint tokenization and parsing,
|
|
||||||
# regard these subtok labels as missing
|
|
||||||
gold.c.heads[i] = i
|
|
||||||
gold.c.labels[i] = 0
|
|
||||||
gold.c.has_dep[i] = False
|
|
||||||
else:
|
|
||||||
if head > i:
|
|
||||||
action = LEFT
|
|
||||||
elif head < i:
|
|
||||||
action = RIGHT
|
|
||||||
else:
|
|
||||||
action = BREAK
|
|
||||||
if dep not in self.labels[action]:
|
|
||||||
if action == BREAK:
|
|
||||||
dep = 'ROOT'
|
|
||||||
elif nonproj.is_decorated(dep):
|
|
||||||
backoff = nonproj.decompose(dep)[0]
|
|
||||||
if backoff in self.labels[action]:
|
|
||||||
dep = backoff
|
|
||||||
else:
|
|
||||||
dep = 'dep'
|
|
||||||
else:
|
|
||||||
dep = 'dep'
|
|
||||||
gold.c.has_dep[i] = True
|
|
||||||
if dep.upper() == 'ROOT':
|
|
||||||
dep = 'ROOT'
|
|
||||||
gold.c.heads[i] = head
|
|
||||||
gold.c.labels[i] = self.strings.add(dep)
|
|
||||||
return gold
|
|
||||||
|
|
||||||
def get_beam_parses(self, Beam beam):
|
def get_beam_parses(self, Beam beam):
|
||||||
parses = []
|
parses = []
|
||||||
|
@ -569,7 +505,9 @@ cdef class ArcEager(TransitionSystem):
|
||||||
output[i] = is_valid[self.c[i].move]
|
output[i] = is_valid[self.c[i].move]
|
||||||
|
|
||||||
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
||||||
StateClass stcls, GoldParse gold) except -1:
|
StateClass stcls, NewExample example) except -1:
|
||||||
|
cdef Pool mem = Pool()
|
||||||
|
gold_state = create_gold_state(mem, stcls, example)
|
||||||
cdef int i, move
|
cdef int i, move
|
||||||
cdef attr_t label
|
cdef attr_t label
|
||||||
cdef label_cost_func_t[N_MOVES] label_cost_funcs
|
cdef label_cost_func_t[N_MOVES] label_cost_funcs
|
||||||
|
@ -599,8 +537,8 @@ cdef class ArcEager(TransitionSystem):
|
||||||
move = self.c[i].move
|
move = self.c[i].move
|
||||||
label = self.c[i].label
|
label = self.c[i].label
|
||||||
if move_costs[move] == 9000:
|
if move_costs[move] == 9000:
|
||||||
move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
|
move_costs[move] = move_cost_funcs[move](stcls, gold_state)
|
||||||
costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
|
costs[i] = move_costs[move] + label_cost_funcs[move](stcls, gold_state, label)
|
||||||
n_gold += costs[i] <= 0
|
n_gold += costs[i] <= 0
|
||||||
else:
|
else:
|
||||||
is_valid[i] = False
|
is_valid[i] = False
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
from .transition_system cimport TransitionSystem
|
from .transition_system cimport TransitionSystem
|
||||||
from .transition_system cimport Transition
|
from .transition_system cimport Transition
|
||||||
from ..gold cimport GoldParseC
|
|
||||||
from ..typedefs cimport attr_t
|
from ..typedefs cimport attr_t
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -7,11 +7,11 @@ from .stateclass cimport StateClass
|
||||||
from ._state cimport StateC
|
from ._state cimport StateC
|
||||||
from .transition_system cimport Transition
|
from .transition_system cimport Transition
|
||||||
from .transition_system cimport do_func_t
|
from .transition_system cimport do_func_t
|
||||||
from ..gold cimport GoldParseC, GoldParse
|
|
||||||
from ..lexeme cimport Lexeme
|
from ..lexeme cimport Lexeme
|
||||||
from ..attrs cimport IS_SPACE
|
from ..attrs cimport IS_SPACE
|
||||||
|
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
|
from .gold_parse cimport GoldParseC
|
||||||
|
|
||||||
|
|
||||||
cdef enum:
|
cdef enum:
|
||||||
|
@ -91,19 +91,11 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
else:
|
else:
|
||||||
return MOVE_NAMES[move] + '-' + self.strings[label]
|
return MOVE_NAMES[move] + '-' + self.strings[label]
|
||||||
|
|
||||||
def has_gold(self, GoldParse gold, start=0, end=None):
|
def has_gold(self, gold, start=0, end=None):
|
||||||
end = end or len(gold.ner)
|
raise NotImplementedError
|
||||||
if all([tag in ('-', None) for tag in gold.ner[start:end]]):
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def preprocess_gold(self, GoldParse gold):
|
def preprocess_gold(self, gold):
|
||||||
if not self.has_gold(gold):
|
raise NotImplementedError
|
||||||
return None
|
|
||||||
for i in range(gold.length):
|
|
||||||
gold.c.ner[i] = self.lookup_transition(gold.ner[i])
|
|
||||||
return gold
|
|
||||||
|
|
||||||
def get_beam_annot(self, Beam beam):
|
def get_beam_annot(self, Beam beam):
|
||||||
entities = {}
|
entities = {}
|
||||||
|
@ -248,7 +240,7 @@ cdef class Missing:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||||
return 9000
|
return 9000
|
||||||
|
|
||||||
|
|
||||||
|
@ -300,7 +292,8 @@ cdef class Begin:
|
||||||
st.pop()
|
st.pop()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||||
|
gold = <GoldParseC*>_gold
|
||||||
cdef int g_act = gold.ner[s.B(0)].move
|
cdef int g_act = gold.ner[s.B(0)].move
|
||||||
cdef attr_t g_tag = gold.ner[s.B(0)].label
|
cdef attr_t g_tag = gold.ner[s.B(0)].label
|
||||||
|
|
||||||
|
@ -363,7 +356,8 @@ cdef class In:
|
||||||
st.pop()
|
st.pop()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||||
|
gold = <GoldParseC*>_gold
|
||||||
move = IN
|
move = IN
|
||||||
cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT
|
cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT
|
||||||
cdef int g_act = gold.ner[s.B(0)].move
|
cdef int g_act = gold.ner[s.B(0)].move
|
||||||
|
@ -429,7 +423,8 @@ cdef class Last:
|
||||||
st.pop()
|
st.pop()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||||
|
gold = <GoldParseC*>_gold
|
||||||
move = LAST
|
move = LAST
|
||||||
|
|
||||||
cdef int g_act = gold.ner[s.B(0)].move
|
cdef int g_act = gold.ner[s.B(0)].move
|
||||||
|
@ -497,7 +492,8 @@ cdef class Unit:
|
||||||
st.pop()
|
st.pop()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||||
|
gold = <GoldParseC*>_gold
|
||||||
cdef int g_act = gold.ner[s.B(0)].move
|
cdef int g_act = gold.ner[s.B(0)].move
|
||||||
cdef attr_t g_tag = gold.ner[s.B(0)].label
|
cdef attr_t g_tag = gold.ner[s.B(0)].label
|
||||||
|
|
||||||
|
@ -537,7 +533,8 @@ cdef class Out:
|
||||||
st.pop()
|
st.pop()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||||
|
gold = <GoldParseC*>_gold
|
||||||
cdef int g_act = gold.ner[s.B(0)].move
|
cdef int g_act = gold.ner[s.B(0)].move
|
||||||
cdef attr_t g_tag = gold.ner[s.B(0)].label
|
cdef attr_t g_tag = gold.ner[s.B(0)].label
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ from ..structs cimport TokenC
|
||||||
from ..strings cimport StringStore
|
from ..strings cimport StringStore
|
||||||
from .stateclass cimport StateClass
|
from .stateclass cimport StateClass
|
||||||
from ._state cimport StateC
|
from ._state cimport StateC
|
||||||
|
from ..gold.new_example cimport NewExample
|
||||||
|
|
||||||
|
|
||||||
cdef struct Transition:
|
cdef struct Transition:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user