mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 04:08:09 +03:00
Let beam forward use minibatches
This commit is contained in:
parent
855872f872
commit
d274d3a3b9
|
@ -17,7 +17,7 @@ from cpython.ref cimport PyObject, Py_XDECREF
|
|||
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
|
||||
from libc.math cimport exp
|
||||
from libcpp.vector cimport vector
|
||||
from libc.string cimport memset
|
||||
from libc.string cimport memset, memcpy
|
||||
from libc.stdlib cimport calloc, free
|
||||
from cymem.cymem cimport Pool
|
||||
from thinc.typedefs cimport weight_t, class_t, hash_t
|
||||
|
@ -485,14 +485,14 @@ cdef class Parser:
|
|||
cdef np.ndarray scores
|
||||
cdef Doc doc
|
||||
cdef int nr_class = self.moves.n_moves
|
||||
cdef StateClass stcls, output
|
||||
cuda_stream = util.get_cuda_stream()
|
||||
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(
|
||||
docs, cuda_stream, 0.0)
|
||||
beams = []
|
||||
cdef int offset = 0
|
||||
cdef int j = 0
|
||||
cdef int k
|
||||
|
||||
beams = []
|
||||
for doc in docs:
|
||||
beam = Beam(nr_class, beam_width, min_density=beam_density)
|
||||
beam.initialize(self.moves.init_beam_state, doc.length, doc.c)
|
||||
|
@ -501,34 +501,43 @@ cdef class Parser:
|
|||
state.offset = offset
|
||||
offset += len(doc)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
while not beam.is_done:
|
||||
states = []
|
||||
beams.append(beam)
|
||||
cdef np.ndarray token_ids
|
||||
token_ids = numpy.zeros((len(docs) * beam_width, self.nr_feature),
|
||||
dtype='i', order='C')
|
||||
todo = [beam for beam in beams if not beam.is_done]
|
||||
|
||||
cdef int* c_ids
|
||||
cdef int nr_feature = self.nr_feature
|
||||
cdef int n_states
|
||||
while todo:
|
||||
todo = [beam for beam in beams if not beam.is_done]
|
||||
token_ids.fill(-1)
|
||||
c_ids = <int*>token_ids.data
|
||||
n_states = 0
|
||||
for beam in todo:
|
||||
for i in range(beam.size):
|
||||
stcls = StateClass.borrow(<StateC*>beam.at(i))
|
||||
state = <StateC*>beam.at(i)
|
||||
# This way we avoid having to score finalized states
|
||||
# We do have to take care to keep indexes aligned, though
|
||||
if not stcls.is_final():
|
||||
states.append(stcls)
|
||||
token_ids = self.get_token_ids(states)
|
||||
vectors = state2vec(token_ids)
|
||||
if self.cfg.get('hist_size', 0):
|
||||
hists = numpy.asarray([st.history[:self.cfg['hist_size']]
|
||||
for st in states], dtype='i')
|
||||
scores = vec2scores((vectors, hists))
|
||||
else:
|
||||
scores = vec2scores(vectors)
|
||||
j = 0
|
||||
c_scores = <float*>scores.data
|
||||
if not state.is_final():
|
||||
state.set_context_tokens(c_ids, nr_feature)
|
||||
c_ids += nr_feature
|
||||
n_states += 1
|
||||
if n_states == 0:
|
||||
break
|
||||
vectors = state2vec(token_ids[:n_states])
|
||||
scores = vec2scores(vectors)
|
||||
c_scores = <float*>scores.data
|
||||
for beam in todo:
|
||||
for i in range(beam.size):
|
||||
state = <StateC*>beam.at(i)
|
||||
if not state.is_final():
|
||||
self.moves.set_valid(beam.is_valid[i], state)
|
||||
for k in range(nr_class):
|
||||
beam.scores[i][k] = c_scores[j * scores.shape[1] + k]
|
||||
j += 1
|
||||
memcpy(beam.scores[i], c_scores, nr_class * sizeof(float))
|
||||
c_scores += nr_class
|
||||
beam.advance(_transition_state, NULL, <void*>self.moves.c)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
beams.append(beam)
|
||||
tokvecs = self.model[0].ops.unflatten(tokvecs,
|
||||
[len(doc) for doc in docs])
|
||||
return beams, tokvecs
|
||||
|
@ -536,7 +545,7 @@ cdef class Parser:
|
|||
def update(self, docs, golds, drop=0., sgd=None, losses=None):
|
||||
if not any(self.moves.has_gold(gold) for gold in golds):
|
||||
return None
|
||||
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.5:
|
||||
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.0:
|
||||
return self.update_beam(docs, golds,
|
||||
self.cfg['beam_width'], self.cfg['beam_density'],
|
||||
drop=drop, sgd=sgd, losses=losses)
|
||||
|
|
Loading…
Reference in New Issue
Block a user