mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Support dropout in beam parse
This commit is contained in:
parent
31b156d60b
commit
3d6487c734
|
@ -477,14 +477,15 @@ cdef class Parser:
|
|||
free(vectors)
|
||||
free(scores)
|
||||
|
||||
def beam_parse(self, docs, int beam_width=3, float beam_density=0.001):
|
||||
def beam_parse(self, docs, int beam_width=3, float beam_density=0.001,
|
||||
float drop=0.):
|
||||
cdef Beam beam
|
||||
cdef np.ndarray scores
|
||||
cdef Doc doc
|
||||
cdef int nr_class = self.moves.n_moves
|
||||
cuda_stream = util.get_cuda_stream()
|
||||
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(
|
||||
docs, cuda_stream, 0.0)
|
||||
docs, cuda_stream, drop)
|
||||
cdef int offset = 0
|
||||
cdef int j = 0
|
||||
cdef int k
|
||||
|
@ -523,8 +524,8 @@ cdef class Parser:
|
|||
n_states += 1
|
||||
if n_states == 0:
|
||||
break
|
||||
vectors = state2vec(token_ids[:n_states])
|
||||
scores = vec2scores(vectors)
|
||||
vectors, _ = state2vec.begin_update(token_ids[:n_states], drop)
|
||||
scores, _ = vec2scores(vectors, drop)
|
||||
c_scores = <float*>scores.data
|
||||
for beam in todo:
|
||||
for i in range(beam.size):
|
||||
|
|
Loading…
Reference in New Issue
Block a user