mirror of
https://github.com/explosion/spaCy.git
synced 2025-03-12 15:25:47 +03:00
Draft create_gold_state for arc_eager oracle
This commit is contained in:
parent
41d29983a7
commit
95de7efaad
|
@ -52,7 +52,6 @@ MOVE_NAMES[BREAK] = 'B'
|
||||||
cdef enum:
|
cdef enum:
|
||||||
HEAD_IN_STACK = 0
|
HEAD_IN_STACK = 0
|
||||||
HEAD_IN_BUFFER
|
HEAD_IN_BUFFER
|
||||||
IS_SENT_START
|
|
||||||
HEAD_UNKNOWN
|
HEAD_UNKNOWN
|
||||||
|
|
||||||
|
|
||||||
|
@ -65,8 +64,72 @@ cdef struct GoldParseStateC:
|
||||||
int32_t length
|
int32_t length
|
||||||
int32_t stride
|
int32_t stride
|
||||||
|
|
||||||
|
|
||||||
cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example example) except *:
    # Build the gold-standard state the arc-eager oracle reads, for the
    # predicted tokens of `example` and the current parser state `stcls`.
    #
    # mem:     Pool from which every array attached to the result is allocated.
    # stcls:   parser state; its stack (stack_depth/S) and buffer
    #          (buffer_length/B) are queried to count children and set flags.
    # example: provides the predicted doc (`x`), the reference doc (`y`) and
    #          `alignment.cand_to_gold`, which maps a predicted-token index to
    #          a reference-token index, or to None when there is no alignment.
    #
    # Returns a GoldParseStateC whose arrays each have `example.x.length`
    # entries, indexed by predicted-token position.
    cdef GoldParseStateC gs
    cdef TokenC ref_tok
    cdef int cand_i, i, word, head
    gs.length = example.x.length
    gs.stride = 1
    # The counters and flag bytes are incremented / updated in place below
    # without being initialised first, so this relies on Pool.alloc handing
    # back zero-filled memory.
    gs.state_bits = <char*>mem.alloc(gs.length, sizeof(gs.state_bits[0]))
    gs.labels = <attr_t*>mem.alloc(gs.length, sizeof(gs.labels[0]))
    gs.heads = <int32_t*>mem.alloc(gs.length, sizeof(gs.heads[0]))
    gs.n_kids_in_buffer = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_buffer[0]))
    gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0]))

    # Copy head offset and dependency label from the aligned reference token.
    cand_to_gold = example.alignment.cand_to_gold
    for cand_i in range(gs.length):
        gold_i = cand_to_gold[cand_i]
        # Test the alignment result, not the loop index: `cand_i` comes from
        # range() and is never None, which previously made the else branch
        # unreachable and indexed `example.y.c` with None for unaligned tokens.
        if gold_i is not None:  # Alignment found
            ref_tok = example.y.c[gold_i]
            gs.heads[cand_i] = ref_tok.head
            gs.labels[cand_i] = ref_tok.dep
            gs.state_bits[cand_i] = set_state_flag(
                gs.state_bits[cand_i],
                HEAD_UNKNOWN,
                0
            )
        else:
            gs.state_bits[cand_i] = set_state_flag(
                gs.state_bits[cand_i],
                HEAD_UNKNOWN,
                1
            )

    # Count, per head, how many of its children currently sit on the stack.
    # gs.heads holds an offset relative to the token, hence `word + offset`.
    stack_words = set()
    for i in range(stcls.stack_depth()):
        word = stcls.S(i)
        head = word + gs.heads[word]
        # The offset was read from the reference doc; skip heads that fall
        # outside the predicted doc instead of writing past the C array.
        if 0 <= head < gs.length:
            gs.n_kids_in_stack[head] += 1
        stack_words.add(word)

    # Same count for children that are still in the buffer.
    buffer_words = set()
    for i in range(stcls.buffer_length()):
        word = stcls.B(i)
        head = word + gs.heads[word]
        if 0 <= head < gs.length:
            gs.n_kids_in_buffer[head] += 1
        buffer_words.add(word)

    # Record for every token whether its head is on the stack or in the
    # buffer. Tokens whose head is in neither keep both flags as allocated.
    for i in range(gs.length):
        head = i + gs.heads[i]
        if head in stack_words:
            gs.state_bits[i] = set_state_flag(
                gs.state_bits[i],
                HEAD_IN_STACK,
                1
            )
            gs.state_bits[i] = set_state_flag(
                gs.state_bits[i],
                HEAD_IN_BUFFER,
                0
            )
        elif head in buffer_words:
            gs.state_bits[i] = set_state_flag(
                gs.state_bits[i],
                HEAD_IN_STACK,
                0
            )
            gs.state_bits[i] = set_state_flag(
                gs.state_bits[i],
                HEAD_IN_BUFFER,
                1
            )
    return gs
|
||||||
|
|
||||||
cdef int check_state_gold(char state_bits, char flag) nogil:
|
cdef int check_state_gold(char state_bits, char flag) nogil:
|
||||||
|
@ -90,10 +153,6 @@ cdef int is_head_in_buffer(const GoldParseStateC* gold, int i) nogil:
|
||||||
return check_state_gold(gold.state_bits[i], HEAD_IN_BUFFER)
|
return check_state_gold(gold.state_bits[i], HEAD_IN_BUFFER)
|
||||||
|
|
||||||
|
|
||||||
cdef int is_sent_start(const GoldParseStateC* gold, int i) nogil:
|
|
||||||
return check_state_gold(gold.state_bits[i], IS_SENT_START)
|
|
||||||
|
|
||||||
|
|
||||||
cdef int is_head_unknown(const GoldParseStateC* gold, int i) nogil:
    # Return check_state_gold's result for the HEAD_UNKNOWN flag on the
    # state bits stored for token `i` of `gold`.
    cdef char token_bits = gold.state_bits[i]
    return check_state_gold(token_bits, HEAD_UNKNOWN)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user