mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Improve implementation of fix #6010
Follow-ups to the parser efficiency fix. * Avoid introducing new counter for number of pushes * Base cut on number of transitions, keeping it more even * Reintroduce the randomization we had in v2.
This commit is contained in:
parent
eb56377799
commit
737a1408d9
|
@ -42,7 +42,6 @@ cdef cppclass StateC:
|
|||
RingBufferC _hist
|
||||
int length
|
||||
int offset
|
||||
int n_pushes
|
||||
int _s_i
|
||||
int _b_i
|
||||
int _e_i
|
||||
|
@ -50,7 +49,6 @@ cdef cppclass StateC:
|
|||
|
||||
__init__(const TokenC* sent, int length) nogil:
|
||||
cdef int PADDING = 5
|
||||
this.n_pushes = 0
|
||||
this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
||||
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
||||
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
|
||||
|
@ -337,7 +335,6 @@ cdef cppclass StateC:
|
|||
this.set_break(this.B_(0).l_edge)
|
||||
if this._b_i > this._break:
|
||||
this._break = -1
|
||||
this.n_pushes += 1
|
||||
|
||||
void pop() nogil:
|
||||
if this._s_i >= 1:
|
||||
|
@ -354,7 +351,6 @@ cdef cppclass StateC:
|
|||
this._buffer[this._b_i] = this.S(0)
|
||||
this._s_i -= 1
|
||||
this.shifted[this.B(0)] = True
|
||||
this.n_pushes -= 1
|
||||
|
||||
void add_arc(int head, int child, attr_t label) nogil:
|
||||
if this.has_head(child):
|
||||
|
@ -435,7 +431,6 @@ cdef cppclass StateC:
|
|||
this._break = src._break
|
||||
this.offset = src.offset
|
||||
this._empty_token = src._empty_token
|
||||
this.n_pushes = src.n_pushes
|
||||
|
||||
void fast_forward() nogil:
|
||||
# space token attachement policy:
|
||||
|
|
|
@ -36,10 +36,6 @@ cdef class StateClass:
|
|||
hist[i] = self.c.get_hist(i+1)
|
||||
return hist
|
||||
|
||||
@property
|
||||
def n_pushes(self):
|
||||
return self.c.n_pushes
|
||||
|
||||
def is_final(self):
|
||||
return self.c.is_final()
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ from itertools import islice
|
|||
from libcpp.vector cimport vector
|
||||
from libc.string cimport memset
|
||||
from libc.stdlib cimport calloc, free
|
||||
import random
|
||||
|
||||
import srsly
|
||||
from thinc.api import set_dropout_rate
|
||||
|
@ -275,22 +276,22 @@ cdef class Parser(Pipe):
|
|||
# Prepare the stepwise model, and get the callback for finishing the batch
|
||||
model, backprop_tok2vec = self.model.begin_update(
|
||||
[eg.predicted for eg in examples])
|
||||
if self.cfg["update_with_oracle_cut_size"] >= 1:
|
||||
# Chop sequences into lengths of this many transitions, to make the
|
||||
max_moves = self.cfg["update_with_oracle_cut_size"]
|
||||
if max_moves >= 1:
|
||||
# Chop sequences into lengths of this many words, to make the
|
||||
# batch uniform length.
|
||||
# We used to randomize this, but it's not clear that actually helps?
|
||||
max_pushes = self.cfg["update_with_oracle_cut_size"]
|
||||
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
||||
states, golds, _ = self._init_gold_batch(
|
||||
examples,
|
||||
max_length=max_pushes
|
||||
max_length=max_moves
|
||||
)
|
||||
else:
|
||||
states, golds, _ = self.moves.init_gold_batch(examples)
|
||||
max_pushes = max([len(eg.x) for eg in examples])
|
||||
if not states:
|
||||
return losses
|
||||
all_states = list(states)
|
||||
states_golds = list(zip(states, golds))
|
||||
n_moves = 0
|
||||
while states_golds:
|
||||
states, golds = zip(*states_golds)
|
||||
scores, backprop = model.begin_update(states)
|
||||
|
@ -302,8 +303,10 @@ cdef class Parser(Pipe):
|
|||
backprop(d_scores)
|
||||
# Follow the predicted action
|
||||
self.transition_states(states, scores)
|
||||
states_golds = [(s, g) for (s, g) in zip(states, golds)
|
||||
if s.n_pushes < max_pushes and not s.is_final()]
|
||||
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
|
||||
if max_moves >= 1 and n_moves >= max_moves:
|
||||
break
|
||||
n_moves += 1
|
||||
|
||||
backprop_tok2vec(golds)
|
||||
if sgd not in (None, False):
|
||||
|
@ -499,7 +502,7 @@ cdef class Parser(Pipe):
|
|||
raise ValueError(Errors.E149) from None
|
||||
return self
|
||||
|
||||
def _init_gold_batch(self, examples, min_length=5, max_length=500):
|
||||
def _init_gold_batch(self, examples, max_length):
|
||||
"""Make a square batch, of length equal to the shortest transition
|
||||
sequence or a cap. A long
|
||||
doc will get multiple states. Let's say we have a doc of length 2*N,
|
||||
|
@ -512,8 +515,7 @@ cdef class Parser(Pipe):
|
|||
all_states = self.moves.init_batch([eg.predicted for eg in examples])
|
||||
states = []
|
||||
golds = []
|
||||
kept = []
|
||||
max_length_seen = 0
|
||||
to_cut = []
|
||||
for state, eg in zip(all_states, examples):
|
||||
if self.moves.has_gold(eg) and not state.is_final():
|
||||
gold = self.moves.init_gold(state, eg)
|
||||
|
@ -523,30 +525,22 @@ cdef class Parser(Pipe):
|
|||
else:
|
||||
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
||||
state.copy(), gold)
|
||||
kept.append((eg, state, gold, oracle_actions))
|
||||
min_length = min(min_length, len(oracle_actions))
|
||||
max_length_seen = max(max_length, len(oracle_actions))
|
||||
if not kept:
|
||||
to_cut.append((eg, state, gold, oracle_actions))
|
||||
if not to_cut:
|
||||
return states, golds, 0
|
||||
max_length = max(min_length, min(max_length, max_length_seen))
|
||||
cdef int clas
|
||||
max_moves = 0
|
||||
for eg, state, gold, oracle_actions in kept:
|
||||
for eg, state, gold, oracle_actions in to_cut:
|
||||
for i in range(0, len(oracle_actions), max_length):
|
||||
start_state = state.copy()
|
||||
n_moves = 0
|
||||
for clas in oracle_actions[i:i+max_length]:
|
||||
action = self.moves.c[clas]
|
||||
action.do(state.c, action.label)
|
||||
state.c.push_hist(action.clas)
|
||||
n_moves += 1
|
||||
if state.is_final():
|
||||
break
|
||||
max_moves = max(max_moves, n_moves)
|
||||
if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
|
||||
states.append(start_state)
|
||||
golds.append(gold)
|
||||
max_moves = max(max_moves, n_moves)
|
||||
if state.is_final():
|
||||
break
|
||||
return states, golds, max_moves
|
||||
return states, golds, max_length
|
||||
|
|
Loading…
Reference in New Issue
Block a user