Fix significant performance bug in parser training (#6010)

The parser training makes use of a trick for long documents, where we
use the oracle to cut up the document into sections, so that we can have
batch items in the middle of a document. For instance, if we have one
document of 600 words, we might make 6 states, starting at words 0, 100,
200, 300, 400 and 500.

The problem is for v3, I screwed this up and didn't stop parsing! So
instead of a batch of [100, 100, 100, 100, 100, 100], we'd have a batch
of [600, 500, 400, 300, 200, 100]. Oops.

The implementation here could probably be improved, it's annoying to
have this extra variable in the state. But this'll do.

This makes the v3 parser training 5-10 times faster, depending on document
lengths. This problem wasn't in v2.
This commit is contained in:
Matthew Honnibal 2020-09-02 12:57:13 +02:00 committed by GitHub
parent 6bfb1b3a29
commit c1bf3a5602
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 5 deletions

View File

@ -42,6 +42,7 @@ cdef cppclass StateC:
RingBufferC _hist RingBufferC _hist
int length int length
int offset int offset
int n_pushes
int _s_i int _s_i
int _b_i int _b_i
int _e_i int _e_i
@ -49,6 +50,7 @@ cdef cppclass StateC:
__init__(const TokenC* sent, int length) nogil: __init__(const TokenC* sent, int length) nogil:
cdef int PADDING = 5 cdef int PADDING = 5
this.n_pushes = 0
this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int)) this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int))
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int)) this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint)) this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
@ -335,6 +337,7 @@ cdef cppclass StateC:
this.set_break(this.B_(0).l_edge) this.set_break(this.B_(0).l_edge)
if this._b_i > this._break: if this._b_i > this._break:
this._break = -1 this._break = -1
this.n_pushes += 1
void pop() nogil: void pop() nogil:
if this._s_i >= 1: if this._s_i >= 1:
@ -351,6 +354,7 @@ cdef cppclass StateC:
this._buffer[this._b_i] = this.S(0) this._buffer[this._b_i] = this.S(0)
this._s_i -= 1 this._s_i -= 1
this.shifted[this.B(0)] = True this.shifted[this.B(0)] = True
this.n_pushes -= 1
void add_arc(int head, int child, attr_t label) nogil: void add_arc(int head, int child, attr_t label) nogil:
if this.has_head(child): if this.has_head(child):
@ -431,6 +435,7 @@ cdef cppclass StateC:
this._break = src._break this._break = src._break
this.offset = src.offset this.offset = src.offset
this._empty_token = src._empty_token this._empty_token = src._empty_token
this.n_pushes = src.n_pushes
void fast_forward() nogil: void fast_forward() nogil:
# space token attachement policy: # space token attachement policy:

View File

@ -36,6 +36,10 @@ cdef class StateClass:
hist[i] = self.c.get_hist(i+1) hist[i] = self.c.get_hist(i+1)
return hist return hist
@property
def n_pushes(self):
return self.c.n_pushes
def is_final(self): def is_final(self):
return self.c.is_final() return self.c.is_final()

View File

@ -279,14 +279,14 @@ cdef class Parser(Pipe):
# Chop sequences into lengths of this many transitions, to make the # Chop sequences into lengths of this many transitions, to make the
# batch uniform length. # batch uniform length.
# We used to randomize this, but it's not clear that actually helps? # We used to randomize this, but it's not clear that actually helps?
cut_size = self.cfg["update_with_oracle_cut_size"] max_pushes = self.cfg["update_with_oracle_cut_size"]
states, golds, max_steps = self._init_gold_batch( states, golds, _ = self._init_gold_batch(
examples, examples,
max_length=cut_size max_length=max_pushes
) )
else: else:
states, golds, _ = self.moves.init_gold_batch(examples) states, golds, _ = self.moves.init_gold_batch(examples)
max_steps = max([len(eg.x) for eg in examples]) max_pushes = max([len(eg.x) for eg in examples])
if not states: if not states:
return losses return losses
all_states = list(states) all_states = list(states)
@ -302,7 +302,8 @@ cdef class Parser(Pipe):
backprop(d_scores) backprop(d_scores)
# Follow the predicted action # Follow the predicted action
self.transition_states(states, scores) self.transition_states(states, scores)
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()] states_golds = [(s, g) for (s, g) in zip(states, golds)
if s.n_pushes < max_pushes and not s.is_final()]
backprop_tok2vec(golds) backprop_tok2vec(golds)
if sgd not in (None, False): if sgd not in (None, False):