Improve implementation of fix #6010

Follow-ups to the parser efficiency fix.

* Avoid introducing new counter for number of pushes
* Base cut on number of transitions, keeping it more even
* Reintroduce the randomization we had in v2.
This commit is contained in:
Matthew Honnibal 2020-09-02 14:42:32 +02:00
parent eb56377799
commit 737a1408d9
3 changed files with 17 additions and 32 deletions

View File

@ -42,7 +42,6 @@ cdef cppclass StateC:
RingBufferC _hist
int length
int offset
int n_pushes
int _s_i
int _b_i
int _e_i
@ -50,7 +49,6 @@ cdef cppclass StateC:
__init__(const TokenC* sent, int length) nogil:
cdef int PADDING = 5
this.n_pushes = 0
this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int))
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
@ -337,7 +335,6 @@ cdef cppclass StateC:
this.set_break(this.B_(0).l_edge)
if this._b_i > this._break:
this._break = -1
this.n_pushes += 1
void pop() nogil:
if this._s_i >= 1:
@ -354,7 +351,6 @@ cdef cppclass StateC:
this._buffer[this._b_i] = this.S(0)
this._s_i -= 1
this.shifted[this.B(0)] = True
this.n_pushes -= 1
void add_arc(int head, int child, attr_t label) nogil:
if this.has_head(child):
@ -435,7 +431,6 @@ cdef cppclass StateC:
this._break = src._break
this.offset = src.offset
this._empty_token = src._empty_token
this.n_pushes = src.n_pushes
void fast_forward() nogil:
# space token attachement policy:

View File

@ -36,10 +36,6 @@ cdef class StateClass:
hist[i] = self.c.get_hist(i+1)
return hist
@property
def n_pushes(self):
return self.c.n_pushes
def is_final(self):
return self.c.is_final()

View File

@ -6,6 +6,7 @@ from itertools import islice
from libcpp.vector cimport vector
from libc.string cimport memset
from libc.stdlib cimport calloc, free
import random
import srsly
from thinc.api import set_dropout_rate
@ -275,22 +276,22 @@ cdef class Parser(Pipe):
# Prepare the stepwise model, and get the callback for finishing the batch
model, backprop_tok2vec = self.model.begin_update(
[eg.predicted for eg in examples])
if self.cfg["update_with_oracle_cut_size"] >= 1:
# Chop sequences into lengths of this many transitions, to make the
max_moves = self.cfg["update_with_oracle_cut_size"]
if max_moves >= 1:
# Chop sequences into lengths of this many words, to make the
# batch uniform length.
# We used to randomize this, but it's not clear that actually helps?
max_pushes = self.cfg["update_with_oracle_cut_size"]
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
states, golds, _ = self._init_gold_batch(
examples,
max_length=max_pushes
max_length=max_moves
)
else:
states, golds, _ = self.moves.init_gold_batch(examples)
max_pushes = max([len(eg.x) for eg in examples])
if not states:
return losses
all_states = list(states)
states_golds = list(zip(states, golds))
n_moves = 0
while states_golds:
states, golds = zip(*states_golds)
scores, backprop = model.begin_update(states)
@ -302,8 +303,10 @@ cdef class Parser(Pipe):
backprop(d_scores)
# Follow the predicted action
self.transition_states(states, scores)
states_golds = [(s, g) for (s, g) in zip(states, golds)
if s.n_pushes < max_pushes and not s.is_final()]
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
if max_moves >= 1 and n_moves >= max_moves:
break
n_moves += 1
backprop_tok2vec(golds)
if sgd not in (None, False):
@ -499,7 +502,7 @@ cdef class Parser(Pipe):
raise ValueError(Errors.E149) from None
return self
def _init_gold_batch(self, examples, min_length=5, max_length=500):
def _init_gold_batch(self, examples, max_length):
"""Make a square batch, of length equal to the shortest transition
sequence or a cap. A long
doc will get multiple states. Let's say we have a doc of length 2*N,
@ -512,8 +515,7 @@ cdef class Parser(Pipe):
all_states = self.moves.init_batch([eg.predicted for eg in examples])
states = []
golds = []
kept = []
max_length_seen = 0
to_cut = []
for state, eg in zip(all_states, examples):
if self.moves.has_gold(eg) and not state.is_final():
gold = self.moves.init_gold(state, eg)
@ -523,30 +525,22 @@ cdef class Parser(Pipe):
else:
oracle_actions = self.moves.get_oracle_sequence_from_state(
state.copy(), gold)
kept.append((eg, state, gold, oracle_actions))
min_length = min(min_length, len(oracle_actions))
max_length_seen = max(max_length, len(oracle_actions))
if not kept:
to_cut.append((eg, state, gold, oracle_actions))
if not to_cut:
return states, golds, 0
max_length = max(min_length, min(max_length, max_length_seen))
cdef int clas
max_moves = 0
for eg, state, gold, oracle_actions in kept:
for eg, state, gold, oracle_actions in to_cut:
for i in range(0, len(oracle_actions), max_length):
start_state = state.copy()
n_moves = 0
for clas in oracle_actions[i:i+max_length]:
action = self.moves.c[clas]
action.do(state.c, action.label)
state.c.push_hist(action.clas)
n_moves += 1
if state.is_final():
break
max_moves = max(max_moves, n_moves)
if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
states.append(start_state)
golds.append(gold)
max_moves = max(max_moves, n_moves)
if state.is_final():
break
return states, golds, max_moves
return states, golds, max_length