Improve efficiency of ArEager oracle

This commit is contained in:
Matthew Honnibal 2020-06-23 15:55:12 +02:00
parent 537a5b9cef
commit d5212f7ba8

View File

@ -1,6 +1,6 @@
# cython: profile=True, cdivision=True, infer_types=True
from cpython.ref cimport Py_INCREF
from cymem.cymem cimport Pool
from cymem.cymem cimport Pool, Address
from libc.stdint cimport int32_t
from collections import defaultdict, Counter
@ -59,26 +59,28 @@ cdef enum:
cdef struct GoldParseStateC:
char* state_bits
attr_t* labels
int32_t* heads
int32_t* n_kids_in_buffer
int32_t* n_kids_in_stack
int32_t* heads
attr_t* labels
int32_t** kids
int32_t* n_kids
int32_t length
int32_t stride
cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example example) except *:
cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls,
heads, labels, sent_starts) except *:
cdef GoldParseStateC gs
gs.length = example.x.length
gs.length = len(heads)
gs.stride = 1
gs.state_bits = <char*>mem.alloc(gs.length, sizeof(gs.state_bits[0]))
gs.labels = <attr_t*>mem.alloc(gs.length, sizeof(gs.labels[0]))
gs.heads = <int32_t*>mem.alloc(gs.length, sizeof(gs.heads[0]))
gs.n_kids = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids[0]))
gs.state_bits = <char*>mem.alloc(gs.length, sizeof(gs.state_bits[0]))
gs.n_kids_in_buffer = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_buffer[0]))
gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0]))
heads, labels = example.get_aligned_parse(projectivize=True)
sent_starts = example.get_aligned("SENT_START")
for i, is_sent_start in enumerate(sent_starts):
if is_sent_start == True:
gs.state_bits[i] = set_state_flag(
@ -115,11 +117,12 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
0
)
cdef TokenC ref_tok
for i, (head, label) in enumerate(zip(heads, labels)):
if head is not None:
gs.heads[i] = head
gs.labels[i] = example.x.vocab.strings.add(label)
gs.labels[i] = label
if i != head:
gs.n_kids[head] += 1
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
HEAD_UNKNOWN,
@ -131,46 +134,24 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
HEAD_UNKNOWN,
1
)
stack_words = set()
for i in range(stcls.stack_depth()):
s_i = stcls.S(i)
head = gs.heads[s_i]
gs.n_kids_in_stack[head] += 1
stack_words.add(s_i)
buffer_words = set()
for i in range(stcls.buffer_length()):
b_i = stcls.B(i)
head = gs.heads[b_i]
gs.n_kids_in_buffer[head] += 1
buffer_words.add(b_i)
# Make an array of pointers, pointing into the gs_kids_flat array.
gs.kids = <int32_t**>mem.alloc(gs.length, sizeof(int32_t*))
for i in range(gs.length):
head = gs.heads[i]
if head in stack_words:
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
HEAD_IN_STACK,
1
)
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
HEAD_IN_BUFFER,
0
)
elif head in buffer_words:
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
HEAD_IN_STACK,
0
)
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
HEAD_IN_BUFFER,
1
)
if gs.n_kids[i] != 0:
gs.kids[i] = <int32_t*>mem.alloc(gs.n_kids[i], sizeof(int32_t))
# This is a temporary buffer
js_addr = Address(gs.length, sizeof(int32_t))
js = <int32_t*>js_addr.ptr
for i in range(gs.length):
if not is_head_unknown(&gs, i):
head = gs.heads[i]
if head != i:
gs.kids[head][js[head]] = i
js[head] += 1
return gs
cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) except *:
cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil:
for i in range(gs.length):
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
@ -184,39 +165,24 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) except *:
)
gs.n_kids_in_stack[i] = 0
gs.n_kids_in_buffer[i] = 0
stack_words = set()
for i in range(stcls.stack_depth()):
s_i = stcls.S(i)
head = gs.heads[s_i]
gs.n_kids_in_stack[head] += 1
stack_words.add(s_i)
buffer_words = set()
for i in range(stcls.buffer_length()):
b_i = stcls.B(i)
head = gs.heads[b_i]
gs.n_kids_in_buffer[head] += 1
buffer_words.add(b_i)
for i in range(gs.length):
head = gs.heads[i]
if head in stack_words:
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
if not is_head_unknown(gs, s_i):
gs.n_kids_in_stack[gs.heads[s_i]] += 1
for kid in gs.kids[s_i][:gs.n_kids[s_i]]:
gs.state_bits[kid] = set_state_flag(
gs.state_bits[kid],
HEAD_IN_STACK,
1
)
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
HEAD_IN_BUFFER,
0
)
elif head in buffer_words:
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
HEAD_IN_STACK,
0
)
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
for i in range(stcls.buffer_length()):
b_i = stcls.B(i)
if not is_head_unknown(gs, b_i):
gs.n_kids_in_buffer[gs.heads[b_i]] += 1
for kid in gs.kids[b_i][:gs.n_kids[b_i]]:
gs.state_bits[kid] = set_state_flag(
gs.state_bits[kid],
HEAD_IN_BUFFER,
1
)
@ -228,13 +194,18 @@ cdef class ArcEagerGold:
def __init__(self, ArcEager moves, StateClass stcls, Example example):
self.mem = Pool()
self.c = create_gold_state(self.mem, stcls, example)
heads, labels = example.get_aligned_parse(projectivize=True)
labels = [label if label is not None else "" for label in labels]
labels = [example.x.vocab.strings.add(label) for label in labels]
sent_starts = example.get_aligned("SENT_START")
assert len(heads) == len(labels) == len(sent_starts)
self.c = create_gold_state(self.mem, stcls, heads, labels, sent_starts)
self.update(stcls)
def update(self, StateClass stcls):
update_gold_state(&self.c, stcls)
cdef int check_state_gold(char state_bits, char flag) nogil:
cdef char one = 1
return state_bits & (one << flag)