From 95de7efaadc742bab2e10d724244daac8074326d Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Mon, 15 Jun 2020 18:10:19 +0200
Subject: [PATCH] Draft create_gold_state for arc_eager oracle

---
 spacy/syntax/arc_eager.pyx | 82 +++++++++++++++++++++++++++++++++++---
 1 file changed, 77 insertions(+), 5 deletions(-)

Review note: the alignment check in create_gold_state now tests `gold_i`
instead of the loop counter, and explanatory comments were added. Hunk
counts are adjusted accordingly; the post-image blob id on the `index`
line below was not recomputed.

diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx
index 787546144..899cc0760 100644
--- a/spacy/syntax/arc_eager.pyx
+++ b/spacy/syntax/arc_eager.pyx
@@ -52,7 +52,6 @@ MOVE_NAMES[BREAK] = 'B'
 cdef enum:
     HEAD_IN_STACK = 0
     HEAD_IN_BUFFER
-    IS_SENT_START
     HEAD_UNKNOWN
 
 
@@ -65,8 +64,85 @@ cdef struct GoldParseStateC:
     int32_t length
     int32_t stride
 
+
+# Build the gold-standard state the arc-eager oracle reads.
+#
+# Every array is allocated from `mem` with one slot per token of `example.x`.
+# Heads and labels are copied from the aligned token of `example.y`; tokens
+# with no alignment are flagged HEAD_UNKNOWN instead. The stack and buffer of
+# `stcls` then determine the child counters and the HEAD_IN_STACK /
+# HEAD_IN_BUFFER flags.
 cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example example) except *:
     cdef GoldParseStateC gs
+    gs.length = example.x.length
+    gs.stride = 1
+    gs.state_bits = mem.alloc(gs.length, sizeof(gs.state_bits[0]))
+    gs.labels = mem.alloc(gs.length, sizeof(gs.labels[0]))
+    gs.heads = mem.alloc(gs.length, sizeof(gs.heads[0]))
+    gs.n_kids_in_buffer = mem.alloc(gs.length, sizeof(gs.n_kids_in_buffer[0]))
+    gs.n_kids_in_stack = mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0]))
+
+    cand_to_gold = example.alignment.cand_to_gold
+    cdef TokenC ref_tok
+    for cand_i in range(example.x.length):
+        gold_i = cand_to_gold[cand_i]
+        # Test the aligned index, not the loop counter: `cand_i` comes from
+        # range() and is never None, which left the else branch unreachable.
+        if gold_i is not None:  # Alignment found
+            ref_tok = example.y.c[gold_i]
+            gs.heads[cand_i] = ref_tok.head
+            gs.labels[cand_i] = ref_tok.dep
+            gs.state_bits[cand_i] = set_state_flag(
+                gs.state_bits[cand_i],
+                HEAD_UNKNOWN,
+                0
+            )
+        else:
+            gs.state_bits[cand_i] = set_state_flag(
+                gs.state_bits[cand_i],
+                HEAD_UNKNOWN,
+                1
+            )
+    # The head index is computed as token index + gs.heads[token index].
+    # NOTE(review): tokens flagged HEAD_UNKNOWN never get gs.heads assigned,
+    # yet still take part in the head arithmetic below -- confirm intended.
+    stack_words = set()
+    for i in range(stcls.stack_depth()):
+        s_i = stcls.S(i)
+        head = s_i + gs.heads[s_i]
+        gs.n_kids_in_stack[head] += 1
+        stack_words.add(s_i)
+    buffer_words = set()
+    for i in range(stcls.buffer_length()):
+        b_i = stcls.B(i)
+        head = b_i + gs.heads[b_i]
+        gs.n_kids_in_buffer[head] += 1
+        buffer_words.add(b_i)
+    # Flag each token by where its head currently sits (stack or buffer).
+    for i in range(gs.length):
+        head = i + gs.heads[i]
+        if head in stack_words:
+            gs.state_bits[i] = set_state_flag(
+                gs.state_bits[i],
+                HEAD_IN_STACK,
+                1
+            )
+            gs.state_bits[i] = set_state_flag(
+                gs.state_bits[i],
+                HEAD_IN_BUFFER,
+                0
+            )
+        elif head in buffer_words:
+            gs.state_bits[i] = set_state_flag(
+                gs.state_bits[i],
+                HEAD_IN_STACK,
+                0
+            )
+            gs.state_bits[i] = set_state_flag(
+                gs.state_bits[i],
+                HEAD_IN_BUFFER,
+                1
+            )
     return gs
 
 cdef int check_state_gold(char state_bits, char flag) nogil:
@@ -90,10 +166,6 @@ cdef int is_head_in_buffer(const GoldParseStateC* gold, int i) nogil:
     return check_state_gold(gold.state_bits[i], HEAD_IN_BUFFER)
 
 
-cdef int is_sent_start(const GoldParseStateC* gold, int i) nogil:
-    return check_state_gold(gold.state_bits[i], IS_SENT_START)
-
-
 cdef int is_head_unknown(const GoldParseStateC* gold, int i) nogil:
     return check_state_gold(gold.state_bits[i], HEAD_UNKNOWN)
 