From c1bf3a5602db211caf3d25abcf3b09ca42aec0f7 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 2 Sep 2020 12:57:13 +0200 Subject: [PATCH] Fix significant performance bug in parser training (#6010) The parser training makes use of a trick for long documents, where we use the oracle to cut up the document into sections, so that we can have batch items in the middle of a document. For instance, if we have one document of 600 words, we might make 6 states, starting at words 0, 100, 200, 300, 400 and 500. The problem is for v3, I screwed this up and didn't stop parsing! So instead of a batch of [100, 100, 100, 100, 100, 100], we'd have a batch of [600, 500, 400, 300, 200, 100]. Oops. The implementation here could probably be improved, it's annoying to have this extra variable in the state. But this'll do. This makes the v3 parser training 5-10 times faster, depending on document lengths. This problem wasn't in v2. --- spacy/pipeline/_parser_internals/_state.pxd | 5 +++++ spacy/pipeline/_parser_internals/stateclass.pyx | 4 ++++ spacy/pipeline/transition_parser.pyx | 11 ++++++----- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/spacy/pipeline/_parser_internals/_state.pxd b/spacy/pipeline/_parser_internals/_state.pxd index 0d0dd8c05..d31430124 100644 --- a/spacy/pipeline/_parser_internals/_state.pxd +++ b/spacy/pipeline/_parser_internals/_state.pxd @@ -42,6 +42,7 @@ cdef cppclass StateC: RingBufferC _hist int length int offset + int n_pushes int _s_i int _b_i int _e_i @@ -49,6 +50,7 @@ cdef cppclass StateC: __init__(const TokenC* sent, int length) nogil: cdef int PADDING = 5 + this.n_pushes = 0 this._buffer = calloc(length + (PADDING * 2), sizeof(int)) this._stack = calloc(length + (PADDING * 2), sizeof(int)) this.shifted = calloc(length + (PADDING * 2), sizeof(bint)) @@ -335,6 +337,7 @@ cdef cppclass StateC: this.set_break(this.B_(0).l_edge) if this._b_i > this._break: this._break = -1 + this.n_pushes += 1 void pop() nogil: if this._s_i >= 1: @@ -351,6 +354,7 @@ cdef cppclass StateC: this._buffer[this._b_i] = this.S(0) this._s_i -= 1 this.shifted[this.B(0)] = True + this.n_pushes -= 1 void add_arc(int head, int child, attr_t label) nogil: if this.has_head(child): @@ -431,6 +435,7 @@ cdef cppclass StateC: this._break = src._break this.offset = src.offset this._empty_token = src._empty_token + this.n_pushes = src.n_pushes void fast_forward() nogil: # space token attachement policy: diff --git a/spacy/pipeline/_parser_internals/stateclass.pyx b/spacy/pipeline/_parser_internals/stateclass.pyx index 880cf6cc5..d59ade467 100644 --- a/spacy/pipeline/_parser_internals/stateclass.pyx +++ b/spacy/pipeline/_parser_internals/stateclass.pyx @@ -36,6 +36,10 @@ cdef class StateClass: hist[i] = self.c.get_hist(i+1) return hist + @property + def n_pushes(self): + return self.c.n_pushes + def is_final(self): return self.c.is_final() diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index 2eadfa6aa..2169b4c17 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -279,14 +279,14 @@ cdef class Parser(Pipe): # Chop sequences into lengths of this many transitions, to make the # batch uniform length. # We used to randomize this, but it's not clear that actually helps? - cut_size = self.cfg["update_with_oracle_cut_size"] - states, golds, max_steps = self._init_gold_batch( + max_pushes = self.cfg["update_with_oracle_cut_size"] + states, golds, _ = self._init_gold_batch( examples, - max_length=cut_size + max_length=max_pushes ) else: states, golds, _ = self.moves.init_gold_batch(examples) - max_steps = max([len(eg.x) for eg in examples]) + max_pushes = max([len(eg.x) for eg in examples]) if not states: return losses all_states = list(states) @@ -302,7 +302,8 @@ cdef class Parser(Pipe): backprop(d_scores) # Follow the predicted action self.transition_states(states, scores) - states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()] + states_golds = [(s, g) for (s, g) in zip(states, golds) + if s.n_pushes < max_pushes and not s.is_final()] backprop_tok2vec(golds) if sgd not in (None, False):