From d5212f7ba8c3735a8a1118e9b844563bda6aa452 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 23 Jun 2020 15:55:12 +0200 Subject: [PATCH] Improve efficiency of ArEager oracle --- spacy/syntax/arc_eager.pyx | 123 ++++++++++++++----------------------- 1 file changed, 47 insertions(+), 76 deletions(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 28787f97d..a6bc10f0c 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -1,6 +1,6 @@ # cython: profile=True, cdivision=True, infer_types=True from cpython.ref cimport Py_INCREF -from cymem.cymem cimport Pool +from cymem.cymem cimport Pool, Address from libc.stdint cimport int32_t from collections import defaultdict, Counter @@ -59,26 +59,28 @@ cdef enum: cdef struct GoldParseStateC: char* state_bits - attr_t* labels - int32_t* heads int32_t* n_kids_in_buffer int32_t* n_kids_in_stack + int32_t* heads + attr_t* labels + int32_t** kids + int32_t* n_kids int32_t length int32_t stride -cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example example) except *: +cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, + heads, labels, sent_starts) except *: cdef GoldParseStateC gs - gs.length = example.x.length + gs.length = len(heads) gs.stride = 1 - gs.state_bits = mem.alloc(gs.length, sizeof(gs.state_bits[0])) gs.labels = mem.alloc(gs.length, sizeof(gs.labels[0])) gs.heads = mem.alloc(gs.length, sizeof(gs.heads[0])) + gs.n_kids = mem.alloc(gs.length, sizeof(gs.n_kids[0])) + gs.state_bits = mem.alloc(gs.length, sizeof(gs.state_bits[0])) gs.n_kids_in_buffer = mem.alloc(gs.length, sizeof(gs.n_kids_in_buffer[0])) gs.n_kids_in_stack = mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0])) - heads, labels = example.get_aligned_parse(projectivize=True) - sent_starts = example.get_aligned("SENT_START") for i, is_sent_start in enumerate(sent_starts): if is_sent_start == True: gs.state_bits[i] = set_state_flag( @@ -115,11 +117,12 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp 0 ) - cdef TokenC ref_tok for i, (head, label) in enumerate(zip(heads, labels)): if head is not None: gs.heads[i] = head - gs.labels[i] = example.x.vocab.strings.add(label) + gs.labels[i] = label + if i != head: + gs.n_kids[head] += 1 gs.state_bits[i] = set_state_flag( gs.state_bits[i], HEAD_UNKNOWN, @@ -131,46 +134,24 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp HEAD_UNKNOWN, 1 ) - stack_words = set() - for i in range(stcls.stack_depth()): - s_i = stcls.S(i) - head = gs.heads[s_i] - gs.n_kids_in_stack[head] += 1 - stack_words.add(s_i) - buffer_words = set() - for i in range(stcls.buffer_length()): - b_i = stcls.B(i) - head = gs.heads[b_i] - gs.n_kids_in_buffer[head] += 1 - buffer_words.add(b_i) + # Make an array of pointers, pointing into the gs_kids_flat array. + gs.kids = mem.alloc(gs.length, sizeof(int32_t*)) for i in range(gs.length): - head = gs.heads[i] - if head in stack_words: - gs.state_bits[i] = set_state_flag( - gs.state_bits[i], - HEAD_IN_STACK, - 1 - ) - gs.state_bits[i] = set_state_flag( - gs.state_bits[i], - HEAD_IN_BUFFER, - 0 - ) - elif head in buffer_words: - gs.state_bits[i] = set_state_flag( - gs.state_bits[i], - HEAD_IN_STACK, - 0 - ) - gs.state_bits[i] = set_state_flag( - gs.state_bits[i], - HEAD_IN_BUFFER, - 1 - ) + if gs.n_kids[i] != 0: + gs.kids[i] = mem.alloc(gs.n_kids[i], sizeof(int32_t)) + # This is a temporary buffer + js_addr = Address(gs.length, sizeof(int32_t)) + js = js_addr.ptr + for i in range(gs.length): + if not is_head_unknown(&gs, i): + head = gs.heads[i] + if head != i: + gs.kids[head][js[head]] = i + js[head] += 1 return gs -cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) except *: +cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil: for i in range(gs.length): gs.state_bits[i] = set_state_flag( gs.state_bits[i], @@ -184,39 +165,24 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) except *: ) gs.n_kids_in_stack[i] = 0 gs.n_kids_in_buffer[i] = 0 - stack_words = set() + for i in range(stcls.stack_depth()): s_i = stcls.S(i) - head = gs.heads[s_i] - gs.n_kids_in_stack[head] += 1 - stack_words.add(s_i) - buffer_words = set() - for i in range(stcls.buffer_length()): - b_i = stcls.B(i) - head = gs.heads[b_i] - gs.n_kids_in_buffer[head] += 1 - buffer_words.add(b_i) - for i in range(gs.length): - head = gs.heads[i] - if head in stack_words: - gs.state_bits[i] = set_state_flag( - gs.state_bits[i], + if not is_head_unknown(gs, s_i): + gs.n_kids_in_stack[gs.heads[s_i]] += 1 + for kid in gs.kids[s_i][:gs.n_kids[s_i]]: + gs.state_bits[kid] = set_state_flag( + gs.state_bits[kid], HEAD_IN_STACK, 1 ) - gs.state_bits[i] = set_state_flag( - gs.state_bits[i], - HEAD_IN_BUFFER, - 0 - ) - elif head in buffer_words: - gs.state_bits[i] = set_state_flag( - gs.state_bits[i], - HEAD_IN_STACK, - 0 - ) - gs.state_bits[i] = set_state_flag( - gs.state_bits[i], + for i in range(stcls.buffer_length()): + b_i = stcls.B(i) + if not is_head_unknown(gs, b_i): + gs.n_kids_in_buffer[gs.heads[b_i]] += 1 + for kid in gs.kids[b_i][:gs.n_kids[b_i]]: + gs.state_bits[kid] = set_state_flag( + gs.state_bits[kid], HEAD_IN_BUFFER, 1 ) @@ -228,13 +194,18 @@ cdef class ArcEagerGold: def __init__(self, ArcEager moves, StateClass stcls, Example example): self.mem = Pool() - self.c = create_gold_state(self.mem, stcls, example) + heads, labels = example.get_aligned_parse(projectivize=True) + labels = [label if label is not None else "" for label in labels] + labels = [example.x.vocab.strings.add(label) for label in labels] + sent_starts = example.get_aligned("SENT_START") + assert len(heads) == len(labels) == len(sent_starts) + self.c = create_gold_state(self.mem, stcls, heads, labels, sent_starts) + self.update(stcls) def update(self, StateClass stcls): update_gold_state(&self.c, stcls) - cdef int check_state_gold(char state_bits, char flag) nogil: cdef char one = 1 return state_bits & (one << flag)