Improve implementation of fix #6010

Follow-ups to the parser efficiency fix.

* Avoid introducing new counter for number of pushes
* Base cut on number of transitions, keeping it more even
* Reintroduce the randomization we had in v2.
This commit is contained in:
Matthew Honnibal 2020-09-02 14:42:32 +02:00
parent eb56377799
commit 737a1408d9
3 changed files with 17 additions and 32 deletions

View File

@ -42,7 +42,6 @@ cdef cppclass StateC:
RingBufferC _hist RingBufferC _hist
int length int length
int offset int offset
int n_pushes
int _s_i int _s_i
int _b_i int _b_i
int _e_i int _e_i
@ -50,7 +49,6 @@ cdef cppclass StateC:
__init__(const TokenC* sent, int length) nogil: __init__(const TokenC* sent, int length) nogil:
cdef int PADDING = 5 cdef int PADDING = 5
this.n_pushes = 0
this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int)) this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int))
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int)) this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint)) this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
@ -337,7 +335,6 @@ cdef cppclass StateC:
this.set_break(this.B_(0).l_edge) this.set_break(this.B_(0).l_edge)
if this._b_i > this._break: if this._b_i > this._break:
this._break = -1 this._break = -1
this.n_pushes += 1
void pop() nogil: void pop() nogil:
if this._s_i >= 1: if this._s_i >= 1:
@ -354,7 +351,6 @@ cdef cppclass StateC:
this._buffer[this._b_i] = this.S(0) this._buffer[this._b_i] = this.S(0)
this._s_i -= 1 this._s_i -= 1
this.shifted[this.B(0)] = True this.shifted[this.B(0)] = True
this.n_pushes -= 1
void add_arc(int head, int child, attr_t label) nogil: void add_arc(int head, int child, attr_t label) nogil:
if this.has_head(child): if this.has_head(child):
@ -435,7 +431,6 @@ cdef cppclass StateC:
this._break = src._break this._break = src._break
this.offset = src.offset this.offset = src.offset
this._empty_token = src._empty_token this._empty_token = src._empty_token
this.n_pushes = src.n_pushes
void fast_forward() nogil: void fast_forward() nogil:
# space token attachement policy: # space token attachement policy:

View File

@ -36,10 +36,6 @@ cdef class StateClass:
hist[i] = self.c.get_hist(i+1) hist[i] = self.c.get_hist(i+1)
return hist return hist
@property
def n_pushes(self):
return self.c.n_pushes
def is_final(self): def is_final(self):
return self.c.is_final() return self.c.is_final()

View File

@ -6,6 +6,7 @@ from itertools import islice
from libcpp.vector cimport vector from libcpp.vector cimport vector
from libc.string cimport memset from libc.string cimport memset
from libc.stdlib cimport calloc, free from libc.stdlib cimport calloc, free
import random
import srsly import srsly
from thinc.api import set_dropout_rate from thinc.api import set_dropout_rate
@ -275,22 +276,22 @@ cdef class Parser(Pipe):
# Prepare the stepwise model, and get the callback for finishing the batch # Prepare the stepwise model, and get the callback for finishing the batch
model, backprop_tok2vec = self.model.begin_update( model, backprop_tok2vec = self.model.begin_update(
[eg.predicted for eg in examples]) [eg.predicted for eg in examples])
if self.cfg["update_with_oracle_cut_size"] >= 1: max_moves = self.cfg["update_with_oracle_cut_size"]
# Chop sequences into lengths of this many transitions, to make the if max_moves >= 1:
# Chop sequences into lengths of this many words, to make the
# batch uniform length. # batch uniform length.
# We used to randomize this, but it's not clear that actually helps? max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
max_pushes = self.cfg["update_with_oracle_cut_size"]
states, golds, _ = self._init_gold_batch( states, golds, _ = self._init_gold_batch(
examples, examples,
max_length=max_pushes max_length=max_moves
) )
else: else:
states, golds, _ = self.moves.init_gold_batch(examples) states, golds, _ = self.moves.init_gold_batch(examples)
max_pushes = max([len(eg.x) for eg in examples])
if not states: if not states:
return losses return losses
all_states = list(states) all_states = list(states)
states_golds = list(zip(states, golds)) states_golds = list(zip(states, golds))
n_moves = 0
while states_golds: while states_golds:
states, golds = zip(*states_golds) states, golds = zip(*states_golds)
scores, backprop = model.begin_update(states) scores, backprop = model.begin_update(states)
@ -302,8 +303,10 @@ cdef class Parser(Pipe):
backprop(d_scores) backprop(d_scores)
# Follow the predicted action # Follow the predicted action
self.transition_states(states, scores) self.transition_states(states, scores)
states_golds = [(s, g) for (s, g) in zip(states, golds) states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
if s.n_pushes < max_pushes and not s.is_final()] if max_moves >= 1 and n_moves >= max_moves:
break
n_moves += 1
backprop_tok2vec(golds) backprop_tok2vec(golds)
if sgd not in (None, False): if sgd not in (None, False):
@ -499,7 +502,7 @@ cdef class Parser(Pipe):
raise ValueError(Errors.E149) from None raise ValueError(Errors.E149) from None
return self return self
def _init_gold_batch(self, examples, min_length=5, max_length=500): def _init_gold_batch(self, examples, max_length):
"""Make a square batch, of length equal to the shortest transition """Make a square batch, of length equal to the shortest transition
sequence or a cap. A long sequence or a cap. A long
doc will get multiple states. Let's say we have a doc of length 2*N, doc will get multiple states. Let's say we have a doc of length 2*N,
@ -512,8 +515,7 @@ cdef class Parser(Pipe):
all_states = self.moves.init_batch([eg.predicted for eg in examples]) all_states = self.moves.init_batch([eg.predicted for eg in examples])
states = [] states = []
golds = [] golds = []
kept = [] to_cut = []
max_length_seen = 0
for state, eg in zip(all_states, examples): for state, eg in zip(all_states, examples):
if self.moves.has_gold(eg) and not state.is_final(): if self.moves.has_gold(eg) and not state.is_final():
gold = self.moves.init_gold(state, eg) gold = self.moves.init_gold(state, eg)
@ -523,30 +525,22 @@ cdef class Parser(Pipe):
else: else:
oracle_actions = self.moves.get_oracle_sequence_from_state( oracle_actions = self.moves.get_oracle_sequence_from_state(
state.copy(), gold) state.copy(), gold)
kept.append((eg, state, gold, oracle_actions)) to_cut.append((eg, state, gold, oracle_actions))
min_length = min(min_length, len(oracle_actions)) if not to_cut:
max_length_seen = max(max_length, len(oracle_actions))
if not kept:
return states, golds, 0 return states, golds, 0
max_length = max(min_length, min(max_length, max_length_seen))
cdef int clas cdef int clas
max_moves = 0 for eg, state, gold, oracle_actions in to_cut:
for eg, state, gold, oracle_actions in kept:
for i in range(0, len(oracle_actions), max_length): for i in range(0, len(oracle_actions), max_length):
start_state = state.copy() start_state = state.copy()
n_moves = 0
for clas in oracle_actions[i:i+max_length]: for clas in oracle_actions[i:i+max_length]:
action = self.moves.c[clas] action = self.moves.c[clas]
action.do(state.c, action.label) action.do(state.c, action.label)
state.c.push_hist(action.clas) state.c.push_hist(action.clas)
n_moves += 1
if state.is_final(): if state.is_final():
break break
max_moves = max(max_moves, n_moves)
if self.moves.has_gold(eg, start_state.B(0), state.B(0)): if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
states.append(start_state) states.append(start_state)
golds.append(gold) golds.append(gold)
max_moves = max(max_moves, n_moves)
if state.is_final(): if state.is_final():
break break
return states, golds, max_moves return states, golds, max_length