mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
Improve implementation of fix #6010
Follow-ups to the parser efficiency fix. * Avoid introducing new counter for number of pushes * Base cut on number of transitions, keeping it more even * Reintroduce the randomization we had in v2.
This commit is contained in:
parent
eb56377799
commit
737a1408d9
|
@ -42,7 +42,6 @@ cdef cppclass StateC:
|
||||||
RingBufferC _hist
|
RingBufferC _hist
|
||||||
int length
|
int length
|
||||||
int offset
|
int offset
|
||||||
int n_pushes
|
|
||||||
int _s_i
|
int _s_i
|
||||||
int _b_i
|
int _b_i
|
||||||
int _e_i
|
int _e_i
|
||||||
|
@ -50,7 +49,6 @@ cdef cppclass StateC:
|
||||||
|
|
||||||
__init__(const TokenC* sent, int length) nogil:
|
__init__(const TokenC* sent, int length) nogil:
|
||||||
cdef int PADDING = 5
|
cdef int PADDING = 5
|
||||||
this.n_pushes = 0
|
|
||||||
this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
||||||
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
||||||
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
|
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
|
||||||
|
@ -337,7 +335,6 @@ cdef cppclass StateC:
|
||||||
this.set_break(this.B_(0).l_edge)
|
this.set_break(this.B_(0).l_edge)
|
||||||
if this._b_i > this._break:
|
if this._b_i > this._break:
|
||||||
this._break = -1
|
this._break = -1
|
||||||
this.n_pushes += 1
|
|
||||||
|
|
||||||
void pop() nogil:
|
void pop() nogil:
|
||||||
if this._s_i >= 1:
|
if this._s_i >= 1:
|
||||||
|
@ -354,7 +351,6 @@ cdef cppclass StateC:
|
||||||
this._buffer[this._b_i] = this.S(0)
|
this._buffer[this._b_i] = this.S(0)
|
||||||
this._s_i -= 1
|
this._s_i -= 1
|
||||||
this.shifted[this.B(0)] = True
|
this.shifted[this.B(0)] = True
|
||||||
this.n_pushes -= 1
|
|
||||||
|
|
||||||
void add_arc(int head, int child, attr_t label) nogil:
|
void add_arc(int head, int child, attr_t label) nogil:
|
||||||
if this.has_head(child):
|
if this.has_head(child):
|
||||||
|
@ -435,7 +431,6 @@ cdef cppclass StateC:
|
||||||
this._break = src._break
|
this._break = src._break
|
||||||
this.offset = src.offset
|
this.offset = src.offset
|
||||||
this._empty_token = src._empty_token
|
this._empty_token = src._empty_token
|
||||||
this.n_pushes = src.n_pushes
|
|
||||||
|
|
||||||
void fast_forward() nogil:
|
void fast_forward() nogil:
|
||||||
# space token attachement policy:
|
# space token attachement policy:
|
||||||
|
|
|
@ -36,10 +36,6 @@ cdef class StateClass:
|
||||||
hist[i] = self.c.get_hist(i+1)
|
hist[i] = self.c.get_hist(i+1)
|
||||||
return hist
|
return hist
|
||||||
|
|
||||||
@property
|
|
||||||
def n_pushes(self):
|
|
||||||
return self.c.n_pushes
|
|
||||||
|
|
||||||
def is_final(self):
|
def is_final(self):
|
||||||
return self.c.is_final()
|
return self.c.is_final()
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ from itertools import islice
|
||||||
from libcpp.vector cimport vector
|
from libcpp.vector cimport vector
|
||||||
from libc.string cimport memset
|
from libc.string cimport memset
|
||||||
from libc.stdlib cimport calloc, free
|
from libc.stdlib cimport calloc, free
|
||||||
|
import random
|
||||||
|
|
||||||
import srsly
|
import srsly
|
||||||
from thinc.api import set_dropout_rate
|
from thinc.api import set_dropout_rate
|
||||||
|
@ -275,22 +276,22 @@ cdef class Parser(Pipe):
|
||||||
# Prepare the stepwise model, and get the callback for finishing the batch
|
# Prepare the stepwise model, and get the callback for finishing the batch
|
||||||
model, backprop_tok2vec = self.model.begin_update(
|
model, backprop_tok2vec = self.model.begin_update(
|
||||||
[eg.predicted for eg in examples])
|
[eg.predicted for eg in examples])
|
||||||
if self.cfg["update_with_oracle_cut_size"] >= 1:
|
max_moves = self.cfg["update_with_oracle_cut_size"]
|
||||||
# Chop sequences into lengths of this many transitions, to make the
|
if max_moves >= 1:
|
||||||
|
# Chop sequences into lengths of this many words, to make the
|
||||||
# batch uniform length.
|
# batch uniform length.
|
||||||
# We used to randomize this, but it's not clear that actually helps?
|
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
||||||
max_pushes = self.cfg["update_with_oracle_cut_size"]
|
|
||||||
states, golds, _ = self._init_gold_batch(
|
states, golds, _ = self._init_gold_batch(
|
||||||
examples,
|
examples,
|
||||||
max_length=max_pushes
|
max_length=max_moves
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
states, golds, _ = self.moves.init_gold_batch(examples)
|
states, golds, _ = self.moves.init_gold_batch(examples)
|
||||||
max_pushes = max([len(eg.x) for eg in examples])
|
|
||||||
if not states:
|
if not states:
|
||||||
return losses
|
return losses
|
||||||
all_states = list(states)
|
all_states = list(states)
|
||||||
states_golds = list(zip(states, golds))
|
states_golds = list(zip(states, golds))
|
||||||
|
n_moves = 0
|
||||||
while states_golds:
|
while states_golds:
|
||||||
states, golds = zip(*states_golds)
|
states, golds = zip(*states_golds)
|
||||||
scores, backprop = model.begin_update(states)
|
scores, backprop = model.begin_update(states)
|
||||||
|
@ -302,8 +303,10 @@ cdef class Parser(Pipe):
|
||||||
backprop(d_scores)
|
backprop(d_scores)
|
||||||
# Follow the predicted action
|
# Follow the predicted action
|
||||||
self.transition_states(states, scores)
|
self.transition_states(states, scores)
|
||||||
states_golds = [(s, g) for (s, g) in zip(states, golds)
|
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
|
||||||
if s.n_pushes < max_pushes and not s.is_final()]
|
if max_moves >= 1 and n_moves >= max_moves:
|
||||||
|
break
|
||||||
|
n_moves += 1
|
||||||
|
|
||||||
backprop_tok2vec(golds)
|
backprop_tok2vec(golds)
|
||||||
if sgd not in (None, False):
|
if sgd not in (None, False):
|
||||||
|
@ -499,7 +502,7 @@ cdef class Parser(Pipe):
|
||||||
raise ValueError(Errors.E149) from None
|
raise ValueError(Errors.E149) from None
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _init_gold_batch(self, examples, min_length=5, max_length=500):
|
def _init_gold_batch(self, examples, max_length):
|
||||||
"""Make a square batch, of length equal to the shortest transition
|
"""Make a square batch, of length equal to the shortest transition
|
||||||
sequence or a cap. A long
|
sequence or a cap. A long
|
||||||
doc will get multiple states. Let's say we have a doc of length 2*N,
|
doc will get multiple states. Let's say we have a doc of length 2*N,
|
||||||
|
@ -512,8 +515,7 @@ cdef class Parser(Pipe):
|
||||||
all_states = self.moves.init_batch([eg.predicted for eg in examples])
|
all_states = self.moves.init_batch([eg.predicted for eg in examples])
|
||||||
states = []
|
states = []
|
||||||
golds = []
|
golds = []
|
||||||
kept = []
|
to_cut = []
|
||||||
max_length_seen = 0
|
|
||||||
for state, eg in zip(all_states, examples):
|
for state, eg in zip(all_states, examples):
|
||||||
if self.moves.has_gold(eg) and not state.is_final():
|
if self.moves.has_gold(eg) and not state.is_final():
|
||||||
gold = self.moves.init_gold(state, eg)
|
gold = self.moves.init_gold(state, eg)
|
||||||
|
@ -523,30 +525,22 @@ cdef class Parser(Pipe):
|
||||||
else:
|
else:
|
||||||
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
||||||
state.copy(), gold)
|
state.copy(), gold)
|
||||||
kept.append((eg, state, gold, oracle_actions))
|
to_cut.append((eg, state, gold, oracle_actions))
|
||||||
min_length = min(min_length, len(oracle_actions))
|
if not to_cut:
|
||||||
max_length_seen = max(max_length, len(oracle_actions))
|
|
||||||
if not kept:
|
|
||||||
return states, golds, 0
|
return states, golds, 0
|
||||||
max_length = max(min_length, min(max_length, max_length_seen))
|
|
||||||
cdef int clas
|
cdef int clas
|
||||||
max_moves = 0
|
for eg, state, gold, oracle_actions in to_cut:
|
||||||
for eg, state, gold, oracle_actions in kept:
|
|
||||||
for i in range(0, len(oracle_actions), max_length):
|
for i in range(0, len(oracle_actions), max_length):
|
||||||
start_state = state.copy()
|
start_state = state.copy()
|
||||||
n_moves = 0
|
|
||||||
for clas in oracle_actions[i:i+max_length]:
|
for clas in oracle_actions[i:i+max_length]:
|
||||||
action = self.moves.c[clas]
|
action = self.moves.c[clas]
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
state.c.push_hist(action.clas)
|
state.c.push_hist(action.clas)
|
||||||
n_moves += 1
|
|
||||||
if state.is_final():
|
if state.is_final():
|
||||||
break
|
break
|
||||||
max_moves = max(max_moves, n_moves)
|
|
||||||
if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
|
if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
|
||||||
states.append(start_state)
|
states.append(start_state)
|
||||||
golds.append(gold)
|
golds.append(gold)
|
||||||
max_moves = max(max_moves, n_moves)
|
|
||||||
if state.is_final():
|
if state.is_final():
|
||||||
break
|
break
|
||||||
return states, golds, max_moves
|
return states, golds, max_length
|
||||||
|
|
Loading…
Reference in New Issue
Block a user