Merge pull request #10048 from danieldk/index-arcs-by-head

Use constant-time head lookups in StateC::{L,R}
This commit is contained in:
Daniël de Kok 2022-01-20 13:06:14 +01:00 committed by GitHub
commit 6984f55277
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3,6 +3,7 @@ from libc.string cimport memcpy, memset
from libc.stdlib cimport calloc, free
from libc.stdint cimport uint32_t, uint64_t
cimport libcpp
from libcpp.unordered_map cimport unordered_map
from libcpp.vector cimport vector
from libcpp.set cimport set
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
@ -30,8 +31,8 @@ cdef cppclass StateC:
vector[int] _stack
vector[int] _rebuffer
vector[SpanC] _ents
vector[ArcC] _left_arcs
vector[ArcC] _right_arcs
unordered_map[int, vector[ArcC]] _left_arcs
unordered_map[int, vector[ArcC]] _right_arcs
vector[libcpp.bool] _unshiftable
set[int] _sent_starts
TokenC _empty_token
@ -160,15 +161,22 @@ cdef cppclass StateC:
else:
return &this._sent[i]
void get_arcs(vector[ArcC]* arcs) nogil const:
for i in range(this._left_arcs.size()):
arc = this._left_arcs.at(i)
void map_get_arcs(const unordered_map[int, vector[ArcC]] &heads_arcs, vector[ArcC]* out) nogil const:
cdef const vector[ArcC]* arcs
head_arcs_it = heads_arcs.const_begin()
while head_arcs_it != heads_arcs.const_end():
arcs = &deref(head_arcs_it).second
arcs_it = arcs.const_begin()
while arcs_it != arcs.const_end():
arc = deref(arcs_it)
if arc.head != -1 and arc.child != -1:
arcs.push_back(arc)
for i in range(this._right_arcs.size()):
arc = this._right_arcs.at(i)
if arc.head != -1 and arc.child != -1:
arcs.push_back(arc)
out.push_back(arc)
incr(arcs_it)
incr(head_arcs_it)
void get_arcs(vector[ArcC]* out) nogil const:
this.map_get_arcs(this._left_arcs, out)
this.map_get_arcs(this._right_arcs, out)
int H(int child) nogil const:
if child >= this.length or child < 0:
@ -182,37 +190,35 @@ cdef cppclass StateC:
else:
return this._ents.back().start
int L(int head, int idx) nogil const:
if idx < 1 or this._left_arcs.size() == 0:
int nth_child(const unordered_map[int, vector[ArcC]]& heads_arcs, int head, int idx) nogil const:
if idx < 1:
return -1
# Work backwards through left-arcs to find the arc at the
head_arcs_it = heads_arcs.const_find(head)
if head_arcs_it == heads_arcs.const_end():
return -1
cdef const vector[ArcC]* arcs = &deref(head_arcs_it).second
# Work backwards through arcs to find the arc at the
# requested index more quickly.
cdef size_t child_index = 0
it = this._left_arcs.const_rbegin()
while it != this._left_arcs.rend():
arc = deref(it)
if arc.head == head and arc.child != -1 and arc.child < head:
arcs_it = arcs.const_rbegin()
while arcs_it != arcs.const_rend() and child_index != idx:
arc = deref(arcs_it)
if arc.child != -1:
child_index += 1
if child_index == idx:
return arc.child
incr(it)
incr(arcs_it)
return -1
int L(int head, int idx) nogil const:
return this.nth_child(this._left_arcs, head, idx)
int R(int head, int idx) nogil const:
if idx < 1 or this._right_arcs.size() == 0:
return -1
cdef vector[int] rights
for i in range(this._right_arcs.size()):
arc = this._right_arcs.at(i)
if arc.head == head and arc.child != -1 and arc.child > head:
rights.push_back(arc.child)
idx = (<int>rights.size()) - idx
if idx < 0:
return -1
else:
return rights.at(idx)
return this.nth_child(this._right_arcs, head, idx)
bint empty() nogil const:
return this._stack.size() == 0
@ -254,22 +260,29 @@ cdef cppclass StateC:
int r_edge(int word) nogil const:
return word
int n_L(int head) nogil const:
int n_arcs(const unordered_map[int, vector[ArcC]] &heads_arcs, int head) nogil const:
cdef int n = 0
for i in range(this._left_arcs.size()):
arc = this._left_arcs.at(i)
if arc.head == head and arc.child != -1 and arc.child < arc.head:
n += 1
head_arcs_it = heads_arcs.const_find(head)
if head_arcs_it == heads_arcs.const_end():
return n
int n_R(int head) nogil const:
cdef int n = 0
for i in range(this._right_arcs.size()):
arc = this._right_arcs.at(i)
if arc.head == head and arc.child != -1 and arc.child > arc.head:
cdef const vector[ArcC]* arcs = &deref(head_arcs_it).second
arcs_it = arcs.const_begin()
while arcs_it != arcs.end():
arc = deref(arcs_it)
if arc.child != -1:
n += 1
incr(arcs_it)
return n
int n_L(int head) nogil const:
return n_arcs(this._left_arcs, head)
int n_R(int head) nogil const:
return n_arcs(this._right_arcs, head)
bint stack_is_connected() nogil const:
return False
@ -328,19 +341,20 @@ cdef cppclass StateC:
arc.child = child
arc.label = label
if head > child:
this._left_arcs.push_back(arc)
this._left_arcs[arc.head].push_back(arc)
else:
this._right_arcs.push_back(arc)
this._right_arcs[arc.head].push_back(arc)
this._heads[child] = head
void del_arc(int h_i, int c_i) nogil:
cdef vector[ArcC]* arcs
if h_i > c_i:
arcs = &this._left_arcs
else:
arcs = &this._right_arcs
void map_del_arc(unordered_map[int, vector[ArcC]]* heads_arcs, int h_i, int c_i) nogil:
arcs_it = heads_arcs.find(h_i)
if arcs_it == heads_arcs.end():
return
arcs = &deref(arcs_it).second
if arcs.size() == 0:
return
arc = arcs.back()
if arc.head == h_i and arc.child == c_i:
arcs.pop_back()
@ -353,6 +367,12 @@ cdef cppclass StateC:
arc.label = 0
break
void del_arc(int h_i, int c_i) nogil:
if h_i > c_i:
this.map_del_arc(&this._left_arcs, h_i, c_i)
else:
this.map_del_arc(&this._right_arcs, h_i, c_i)
SpanC get_ent() nogil const:
cdef SpanC ent
if this._ents.size() == 0: