mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-10 16:40:34 +03:00
Improve efficiency of ArcEager oracle
This commit is contained in:
parent
537a5b9cef
commit
d5212f7ba8
|
@ -1,6 +1,6 @@
|
|||
# cython: profile=True, cdivision=True, infer_types=True
|
||||
from cpython.ref cimport Py_INCREF
|
||||
from cymem.cymem cimport Pool
|
||||
from cymem.cymem cimport Pool, Address
|
||||
from libc.stdint cimport int32_t
|
||||
|
||||
from collections import defaultdict, Counter
|
||||
|
@ -59,26 +59,28 @@ cdef enum:
|
|||
|
||||
cdef struct GoldParseStateC:
|
||||
char* state_bits
|
||||
attr_t* labels
|
||||
int32_t* heads
|
||||
int32_t* n_kids_in_buffer
|
||||
int32_t* n_kids_in_stack
|
||||
int32_t* heads
|
||||
attr_t* labels
|
||||
int32_t** kids
|
||||
int32_t* n_kids
|
||||
int32_t length
|
||||
int32_t stride
|
||||
|
||||
|
||||
cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example example) except *:
|
||||
cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls,
|
||||
heads, labels, sent_starts) except *:
|
||||
cdef GoldParseStateC gs
|
||||
gs.length = example.x.length
|
||||
gs.length = len(heads)
|
||||
gs.stride = 1
|
||||
gs.state_bits = <char*>mem.alloc(gs.length, sizeof(gs.state_bits[0]))
|
||||
gs.labels = <attr_t*>mem.alloc(gs.length, sizeof(gs.labels[0]))
|
||||
gs.heads = <int32_t*>mem.alloc(gs.length, sizeof(gs.heads[0]))
|
||||
gs.n_kids = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids[0]))
|
||||
gs.state_bits = <char*>mem.alloc(gs.length, sizeof(gs.state_bits[0]))
|
||||
gs.n_kids_in_buffer = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_buffer[0]))
|
||||
gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0]))
|
||||
|
||||
heads, labels = example.get_aligned_parse(projectivize=True)
|
||||
sent_starts = example.get_aligned("SENT_START")
|
||||
for i, is_sent_start in enumerate(sent_starts):
|
||||
if is_sent_start == True:
|
||||
gs.state_bits[i] = set_state_flag(
|
||||
|
@ -115,11 +117,12 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
|
|||
0
|
||||
)
|
||||
|
||||
cdef TokenC ref_tok
|
||||
for i, (head, label) in enumerate(zip(heads, labels)):
|
||||
if head is not None:
|
||||
gs.heads[i] = head
|
||||
gs.labels[i] = example.x.vocab.strings.add(label)
|
||||
gs.labels[i] = label
|
||||
if i != head:
|
||||
gs.n_kids[head] += 1
|
||||
gs.state_bits[i] = set_state_flag(
|
||||
gs.state_bits[i],
|
||||
HEAD_UNKNOWN,
|
||||
|
@ -131,46 +134,24 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
|
|||
HEAD_UNKNOWN,
|
||||
1
|
||||
)
|
||||
stack_words = set()
|
||||
for i in range(stcls.stack_depth()):
|
||||
s_i = stcls.S(i)
|
||||
head = gs.heads[s_i]
|
||||
gs.n_kids_in_stack[head] += 1
|
||||
stack_words.add(s_i)
|
||||
buffer_words = set()
|
||||
for i in range(stcls.buffer_length()):
|
||||
b_i = stcls.B(i)
|
||||
head = gs.heads[b_i]
|
||||
gs.n_kids_in_buffer[head] += 1
|
||||
buffer_words.add(b_i)
|
||||
# Make an array of pointers, one per token; each non-empty entry gets its own array of that token's kids.
|
||||
gs.kids = <int32_t**>mem.alloc(gs.length, sizeof(int32_t*))
|
||||
for i in range(gs.length):
|
||||
head = gs.heads[i]
|
||||
if head in stack_words:
|
||||
gs.state_bits[i] = set_state_flag(
|
||||
gs.state_bits[i],
|
||||
HEAD_IN_STACK,
|
||||
1
|
||||
)
|
||||
gs.state_bits[i] = set_state_flag(
|
||||
gs.state_bits[i],
|
||||
HEAD_IN_BUFFER,
|
||||
0
|
||||
)
|
||||
elif head in buffer_words:
|
||||
gs.state_bits[i] = set_state_flag(
|
||||
gs.state_bits[i],
|
||||
HEAD_IN_STACK,
|
||||
0
|
||||
)
|
||||
gs.state_bits[i] = set_state_flag(
|
||||
gs.state_bits[i],
|
||||
HEAD_IN_BUFFER,
|
||||
1
|
||||
)
|
||||
if gs.n_kids[i] != 0:
|
||||
gs.kids[i] = <int32_t*>mem.alloc(gs.n_kids[i], sizeof(int32_t))
|
||||
# This is a temporary buffer
|
||||
js_addr = Address(gs.length, sizeof(int32_t))
|
||||
js = <int32_t*>js_addr.ptr
|
||||
for i in range(gs.length):
|
||||
if not is_head_unknown(&gs, i):
|
||||
head = gs.heads[i]
|
||||
if head != i:
|
||||
gs.kids[head][js[head]] = i
|
||||
js[head] += 1
|
||||
return gs
|
||||
|
||||
|
||||
cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) except *:
|
||||
cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil:
|
||||
for i in range(gs.length):
|
||||
gs.state_bits[i] = set_state_flag(
|
||||
gs.state_bits[i],
|
||||
|
@ -184,39 +165,24 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) except *:
|
|||
)
|
||||
gs.n_kids_in_stack[i] = 0
|
||||
gs.n_kids_in_buffer[i] = 0
|
||||
stack_words = set()
|
||||
|
||||
for i in range(stcls.stack_depth()):
|
||||
s_i = stcls.S(i)
|
||||
head = gs.heads[s_i]
|
||||
gs.n_kids_in_stack[head] += 1
|
||||
stack_words.add(s_i)
|
||||
buffer_words = set()
|
||||
for i in range(stcls.buffer_length()):
|
||||
b_i = stcls.B(i)
|
||||
head = gs.heads[b_i]
|
||||
gs.n_kids_in_buffer[head] += 1
|
||||
buffer_words.add(b_i)
|
||||
for i in range(gs.length):
|
||||
head = gs.heads[i]
|
||||
if head in stack_words:
|
||||
gs.state_bits[i] = set_state_flag(
|
||||
gs.state_bits[i],
|
||||
if not is_head_unknown(gs, s_i):
|
||||
gs.n_kids_in_stack[gs.heads[s_i]] += 1
|
||||
for kid in gs.kids[s_i][:gs.n_kids[s_i]]:
|
||||
gs.state_bits[kid] = set_state_flag(
|
||||
gs.state_bits[kid],
|
||||
HEAD_IN_STACK,
|
||||
1
|
||||
)
|
||||
gs.state_bits[i] = set_state_flag(
|
||||
gs.state_bits[i],
|
||||
HEAD_IN_BUFFER,
|
||||
0
|
||||
)
|
||||
elif head in buffer_words:
|
||||
gs.state_bits[i] = set_state_flag(
|
||||
gs.state_bits[i],
|
||||
HEAD_IN_STACK,
|
||||
0
|
||||
)
|
||||
gs.state_bits[i] = set_state_flag(
|
||||
gs.state_bits[i],
|
||||
for i in range(stcls.buffer_length()):
|
||||
b_i = stcls.B(i)
|
||||
if not is_head_unknown(gs, b_i):
|
||||
gs.n_kids_in_buffer[gs.heads[b_i]] += 1
|
||||
for kid in gs.kids[b_i][:gs.n_kids[b_i]]:
|
||||
gs.state_bits[kid] = set_state_flag(
|
||||
gs.state_bits[kid],
|
||||
HEAD_IN_BUFFER,
|
||||
1
|
||||
)
|
||||
|
@ -228,13 +194,18 @@ cdef class ArcEagerGold:
|
|||
|
||||
def __init__(self, ArcEager moves, StateClass stcls, Example example):
|
||||
self.mem = Pool()
|
||||
self.c = create_gold_state(self.mem, stcls, example)
|
||||
heads, labels = example.get_aligned_parse(projectivize=True)
|
||||
labels = [label if label is not None else "" for label in labels]
|
||||
labels = [example.x.vocab.strings.add(label) for label in labels]
|
||||
sent_starts = example.get_aligned("SENT_START")
|
||||
assert len(heads) == len(labels) == len(sent_starts)
|
||||
self.c = create_gold_state(self.mem, stcls, heads, labels, sent_starts)
|
||||
self.update(stcls)
|
||||
|
||||
def update(self, StateClass stcls):
|
||||
update_gold_state(&self.c, stcls)
|
||||
|
||||
|
||||
|
||||
cdef int check_state_gold(char state_bits, char flag) nogil:
|
||||
cdef char one = 1
|
||||
return state_bits & (one << flag)
|
||||
|
|
Loading…
Reference in New Issue
Block a user