mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Fix significant performance bug in parser training (#6010)
The parser training makes use of a trick for long documents, where we use the oracle to cut up the document into sections, so that we can have batch items in the middle of a document. For instance, if we have one document of 600 words, we might make 6 states, starting at words 0, 100, 200, 300, 400 and 500. The problem is for v3, I screwed this up and didn't stop parsing! So instead of a batch of [100, 100, 100, 100, 100, 100], we'd have a batch of [600, 500, 400, 300, 200, 100]. Oops. The implementation here could probably be improved, it's annoying to have this extra variable in the state. But this'll do. This makes the v3 parser training 5-10 times faster, depending on document lengths. This problem wasn't in v2.
This commit is contained in:
parent
6bfb1b3a29
commit
c1bf3a5602
|
@ -42,6 +42,7 @@ cdef cppclass StateC:
|
||||||
RingBufferC _hist
|
RingBufferC _hist
|
||||||
int length
|
int length
|
||||||
int offset
|
int offset
|
||||||
|
int n_pushes
|
||||||
int _s_i
|
int _s_i
|
||||||
int _b_i
|
int _b_i
|
||||||
int _e_i
|
int _e_i
|
||||||
|
@ -49,6 +50,7 @@ cdef cppclass StateC:
|
||||||
|
|
||||||
__init__(const TokenC* sent, int length) nogil:
|
__init__(const TokenC* sent, int length) nogil:
|
||||||
cdef int PADDING = 5
|
cdef int PADDING = 5
|
||||||
|
this.n_pushes = 0
|
||||||
this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
||||||
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
||||||
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
|
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
|
||||||
|
@ -335,6 +337,7 @@ cdef cppclass StateC:
|
||||||
this.set_break(this.B_(0).l_edge)
|
this.set_break(this.B_(0).l_edge)
|
||||||
if this._b_i > this._break:
|
if this._b_i > this._break:
|
||||||
this._break = -1
|
this._break = -1
|
||||||
|
this.n_pushes += 1
|
||||||
|
|
||||||
void pop() nogil:
|
void pop() nogil:
|
||||||
if this._s_i >= 1:
|
if this._s_i >= 1:
|
||||||
|
@ -351,6 +354,7 @@ cdef cppclass StateC:
|
||||||
this._buffer[this._b_i] = this.S(0)
|
this._buffer[this._b_i] = this.S(0)
|
||||||
this._s_i -= 1
|
this._s_i -= 1
|
||||||
this.shifted[this.B(0)] = True
|
this.shifted[this.B(0)] = True
|
||||||
|
this.n_pushes -= 1
|
||||||
|
|
||||||
void add_arc(int head, int child, attr_t label) nogil:
|
void add_arc(int head, int child, attr_t label) nogil:
|
||||||
if this.has_head(child):
|
if this.has_head(child):
|
||||||
|
@ -431,6 +435,7 @@ cdef cppclass StateC:
|
||||||
this._break = src._break
|
this._break = src._break
|
||||||
this.offset = src.offset
|
this.offset = src.offset
|
||||||
this._empty_token = src._empty_token
|
this._empty_token = src._empty_token
|
||||||
|
this.n_pushes = src.n_pushes
|
||||||
|
|
||||||
void fast_forward() nogil:
|
void fast_forward() nogil:
|
||||||
# space token attachement policy:
|
# space token attachement policy:
|
||||||
|
|
|
@ -36,6 +36,10 @@ cdef class StateClass:
|
||||||
hist[i] = self.c.get_hist(i+1)
|
hist[i] = self.c.get_hist(i+1)
|
||||||
return hist
|
return hist
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_pushes(self):
|
||||||
|
return self.c.n_pushes
|
||||||
|
|
||||||
def is_final(self):
|
def is_final(self):
|
||||||
return self.c.is_final()
|
return self.c.is_final()
|
||||||
|
|
||||||
|
|
|
@ -279,14 +279,14 @@ cdef class Parser(Pipe):
|
||||||
# Chop sequences into lengths of this many transitions, to make the
|
# Chop sequences into lengths of this many transitions, to make the
|
||||||
# batch uniform length.
|
# batch uniform length.
|
||||||
# We used to randomize this, but it's not clear that actually helps?
|
# We used to randomize this, but it's not clear that actually helps?
|
||||||
cut_size = self.cfg["update_with_oracle_cut_size"]
|
max_pushes = self.cfg["update_with_oracle_cut_size"]
|
||||||
states, golds, max_steps = self._init_gold_batch(
|
states, golds, _ = self._init_gold_batch(
|
||||||
examples,
|
examples,
|
||||||
max_length=cut_size
|
max_length=max_pushes
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
states, golds, _ = self.moves.init_gold_batch(examples)
|
states, golds, _ = self.moves.init_gold_batch(examples)
|
||||||
max_steps = max([len(eg.x) for eg in examples])
|
max_pushes = max([len(eg.x) for eg in examples])
|
||||||
if not states:
|
if not states:
|
||||||
return losses
|
return losses
|
||||||
all_states = list(states)
|
all_states = list(states)
|
||||||
|
@ -302,7 +302,8 @@ cdef class Parser(Pipe):
|
||||||
backprop(d_scores)
|
backprop(d_scores)
|
||||||
# Follow the predicted action
|
# Follow the predicted action
|
||||||
self.transition_states(states, scores)
|
self.transition_states(states, scores)
|
||||||
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
|
states_golds = [(s, g) for (s, g) in zip(states, golds)
|
||||||
|
if s.n_pushes < max_pushes and not s.is_final()]
|
||||||
|
|
||||||
backprop_tok2vec(golds)
|
backprop_tok2vec(golds)
|
||||||
if sgd not in (None, False):
|
if sgd not in (None, False):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user