mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 00:50:33 +03:00
Improve efficiency of ArcEager oracle
This commit is contained in:
parent
537a5b9cef
commit
d5212f7ba8
|
@ -1,6 +1,6 @@
|
||||||
# cython: profile=True, cdivision=True, infer_types=True
|
# cython: profile=True, cdivision=True, infer_types=True
|
||||||
from cpython.ref cimport Py_INCREF
|
from cpython.ref cimport Py_INCREF
|
||||||
from cymem.cymem cimport Pool
|
from cymem.cymem cimport Pool, Address
|
||||||
from libc.stdint cimport int32_t
|
from libc.stdint cimport int32_t
|
||||||
|
|
||||||
from collections import defaultdict, Counter
|
from collections import defaultdict, Counter
|
||||||
|
@ -59,26 +59,28 @@ cdef enum:
|
||||||
|
|
||||||
cdef struct GoldParseStateC:
|
cdef struct GoldParseStateC:
|
||||||
char* state_bits
|
char* state_bits
|
||||||
attr_t* labels
|
|
||||||
int32_t* heads
|
|
||||||
int32_t* n_kids_in_buffer
|
int32_t* n_kids_in_buffer
|
||||||
int32_t* n_kids_in_stack
|
int32_t* n_kids_in_stack
|
||||||
|
int32_t* heads
|
||||||
|
attr_t* labels
|
||||||
|
int32_t** kids
|
||||||
|
int32_t* n_kids
|
||||||
int32_t length
|
int32_t length
|
||||||
int32_t stride
|
int32_t stride
|
||||||
|
|
||||||
|
|
||||||
cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example example) except *:
|
cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls,
|
||||||
|
heads, labels, sent_starts) except *:
|
||||||
cdef GoldParseStateC gs
|
cdef GoldParseStateC gs
|
||||||
gs.length = example.x.length
|
gs.length = len(heads)
|
||||||
gs.stride = 1
|
gs.stride = 1
|
||||||
gs.state_bits = <char*>mem.alloc(gs.length, sizeof(gs.state_bits[0]))
|
|
||||||
gs.labels = <attr_t*>mem.alloc(gs.length, sizeof(gs.labels[0]))
|
gs.labels = <attr_t*>mem.alloc(gs.length, sizeof(gs.labels[0]))
|
||||||
gs.heads = <int32_t*>mem.alloc(gs.length, sizeof(gs.heads[0]))
|
gs.heads = <int32_t*>mem.alloc(gs.length, sizeof(gs.heads[0]))
|
||||||
|
gs.n_kids = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids[0]))
|
||||||
|
gs.state_bits = <char*>mem.alloc(gs.length, sizeof(gs.state_bits[0]))
|
||||||
gs.n_kids_in_buffer = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_buffer[0]))
|
gs.n_kids_in_buffer = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_buffer[0]))
|
||||||
gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0]))
|
gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0]))
|
||||||
|
|
||||||
heads, labels = example.get_aligned_parse(projectivize=True)
|
|
||||||
sent_starts = example.get_aligned("SENT_START")
|
|
||||||
for i, is_sent_start in enumerate(sent_starts):
|
for i, is_sent_start in enumerate(sent_starts):
|
||||||
if is_sent_start == True:
|
if is_sent_start == True:
|
||||||
gs.state_bits[i] = set_state_flag(
|
gs.state_bits[i] = set_state_flag(
|
||||||
|
@ -115,11 +117,12 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
|
||||||
0
|
0
|
||||||
)
|
)
|
||||||
|
|
||||||
cdef TokenC ref_tok
|
|
||||||
for i, (head, label) in enumerate(zip(heads, labels)):
|
for i, (head, label) in enumerate(zip(heads, labels)):
|
||||||
if head is not None:
|
if head is not None:
|
||||||
gs.heads[i] = head
|
gs.heads[i] = head
|
||||||
gs.labels[i] = example.x.vocab.strings.add(label)
|
gs.labels[i] = label
|
||||||
|
if i != head:
|
||||||
|
gs.n_kids[head] += 1
|
||||||
gs.state_bits[i] = set_state_flag(
|
gs.state_bits[i] = set_state_flag(
|
||||||
gs.state_bits[i],
|
gs.state_bits[i],
|
||||||
HEAD_UNKNOWN,
|
HEAD_UNKNOWN,
|
||||||
|
@ -131,46 +134,24 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
|
||||||
HEAD_UNKNOWN,
|
HEAD_UNKNOWN,
|
||||||
1
|
1
|
||||||
)
|
)
|
||||||
stack_words = set()
|
# Make an array of pointers, pointing into the gs_kids_flat array.
|
||||||
for i in range(stcls.stack_depth()):
|
gs.kids = <int32_t**>mem.alloc(gs.length, sizeof(int32_t*))
|
||||||
s_i = stcls.S(i)
|
|
||||||
head = gs.heads[s_i]
|
|
||||||
gs.n_kids_in_stack[head] += 1
|
|
||||||
stack_words.add(s_i)
|
|
||||||
buffer_words = set()
|
|
||||||
for i in range(stcls.buffer_length()):
|
|
||||||
b_i = stcls.B(i)
|
|
||||||
head = gs.heads[b_i]
|
|
||||||
gs.n_kids_in_buffer[head] += 1
|
|
||||||
buffer_words.add(b_i)
|
|
||||||
for i in range(gs.length):
|
for i in range(gs.length):
|
||||||
head = gs.heads[i]
|
if gs.n_kids[i] != 0:
|
||||||
if head in stack_words:
|
gs.kids[i] = <int32_t*>mem.alloc(gs.n_kids[i], sizeof(int32_t))
|
||||||
gs.state_bits[i] = set_state_flag(
|
# This is a temporary buffer
|
||||||
gs.state_bits[i],
|
js_addr = Address(gs.length, sizeof(int32_t))
|
||||||
HEAD_IN_STACK,
|
js = <int32_t*>js_addr.ptr
|
||||||
1
|
for i in range(gs.length):
|
||||||
)
|
if not is_head_unknown(&gs, i):
|
||||||
gs.state_bits[i] = set_state_flag(
|
head = gs.heads[i]
|
||||||
gs.state_bits[i],
|
if head != i:
|
||||||
HEAD_IN_BUFFER,
|
gs.kids[head][js[head]] = i
|
||||||
0
|
js[head] += 1
|
||||||
)
|
|
||||||
elif head in buffer_words:
|
|
||||||
gs.state_bits[i] = set_state_flag(
|
|
||||||
gs.state_bits[i],
|
|
||||||
HEAD_IN_STACK,
|
|
||||||
0
|
|
||||||
)
|
|
||||||
gs.state_bits[i] = set_state_flag(
|
|
||||||
gs.state_bits[i],
|
|
||||||
HEAD_IN_BUFFER,
|
|
||||||
1
|
|
||||||
)
|
|
||||||
return gs
|
return gs
|
||||||
|
|
||||||
|
|
||||||
cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) except *:
|
cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil:
|
||||||
for i in range(gs.length):
|
for i in range(gs.length):
|
||||||
gs.state_bits[i] = set_state_flag(
|
gs.state_bits[i] = set_state_flag(
|
||||||
gs.state_bits[i],
|
gs.state_bits[i],
|
||||||
|
@ -184,39 +165,24 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) except *:
|
||||||
)
|
)
|
||||||
gs.n_kids_in_stack[i] = 0
|
gs.n_kids_in_stack[i] = 0
|
||||||
gs.n_kids_in_buffer[i] = 0
|
gs.n_kids_in_buffer[i] = 0
|
||||||
stack_words = set()
|
|
||||||
for i in range(stcls.stack_depth()):
|
for i in range(stcls.stack_depth()):
|
||||||
s_i = stcls.S(i)
|
s_i = stcls.S(i)
|
||||||
head = gs.heads[s_i]
|
if not is_head_unknown(gs, s_i):
|
||||||
gs.n_kids_in_stack[head] += 1
|
gs.n_kids_in_stack[gs.heads[s_i]] += 1
|
||||||
stack_words.add(s_i)
|
for kid in gs.kids[s_i][:gs.n_kids[s_i]]:
|
||||||
buffer_words = set()
|
gs.state_bits[kid] = set_state_flag(
|
||||||
for i in range(stcls.buffer_length()):
|
gs.state_bits[kid],
|
||||||
b_i = stcls.B(i)
|
|
||||||
head = gs.heads[b_i]
|
|
||||||
gs.n_kids_in_buffer[head] += 1
|
|
||||||
buffer_words.add(b_i)
|
|
||||||
for i in range(gs.length):
|
|
||||||
head = gs.heads[i]
|
|
||||||
if head in stack_words:
|
|
||||||
gs.state_bits[i] = set_state_flag(
|
|
||||||
gs.state_bits[i],
|
|
||||||
HEAD_IN_STACK,
|
HEAD_IN_STACK,
|
||||||
1
|
1
|
||||||
)
|
)
|
||||||
gs.state_bits[i] = set_state_flag(
|
for i in range(stcls.buffer_length()):
|
||||||
gs.state_bits[i],
|
b_i = stcls.B(i)
|
||||||
HEAD_IN_BUFFER,
|
if not is_head_unknown(gs, b_i):
|
||||||
0
|
gs.n_kids_in_buffer[gs.heads[b_i]] += 1
|
||||||
)
|
for kid in gs.kids[b_i][:gs.n_kids[b_i]]:
|
||||||
elif head in buffer_words:
|
gs.state_bits[kid] = set_state_flag(
|
||||||
gs.state_bits[i] = set_state_flag(
|
gs.state_bits[kid],
|
||||||
gs.state_bits[i],
|
|
||||||
HEAD_IN_STACK,
|
|
||||||
0
|
|
||||||
)
|
|
||||||
gs.state_bits[i] = set_state_flag(
|
|
||||||
gs.state_bits[i],
|
|
||||||
HEAD_IN_BUFFER,
|
HEAD_IN_BUFFER,
|
||||||
1
|
1
|
||||||
)
|
)
|
||||||
|
@ -228,13 +194,18 @@ cdef class ArcEagerGold:
|
||||||
|
|
||||||
def __init__(self, ArcEager moves, StateClass stcls, Example example):
|
def __init__(self, ArcEager moves, StateClass stcls, Example example):
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
self.c = create_gold_state(self.mem, stcls, example)
|
heads, labels = example.get_aligned_parse(projectivize=True)
|
||||||
|
labels = [label if label is not None else "" for label in labels]
|
||||||
|
labels = [example.x.vocab.strings.add(label) for label in labels]
|
||||||
|
sent_starts = example.get_aligned("SENT_START")
|
||||||
|
assert len(heads) == len(labels) == len(sent_starts)
|
||||||
|
self.c = create_gold_state(self.mem, stcls, heads, labels, sent_starts)
|
||||||
|
self.update(stcls)
|
||||||
|
|
||||||
def update(self, StateClass stcls):
|
def update(self, StateClass stcls):
|
||||||
update_gold_state(&self.c, stcls)
|
update_gold_state(&self.c, stcls)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cdef int check_state_gold(char state_bits, char flag) nogil:
|
cdef int check_state_gold(char state_bits, char flag) nogil:
|
||||||
cdef char one = 1
|
cdef char one = 1
|
||||||
return state_bits & (one << flag)
|
return state_bits & (one << flag)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user