mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 04:08:09 +03:00
Fix GPU usage
This commit is contained in:
parent
e98451b5f7
commit
7698903617
|
@ -89,11 +89,14 @@ cdef class precompute_hiddens:
|
|||
cached = gpu_cached.get(stream=cuda_stream)
|
||||
else:
|
||||
cached = gpu_cached
|
||||
if not isinstance(lower_model.b, numpy.ndarray):
|
||||
self.bias = lower_model.b.get()
|
||||
else:
|
||||
self.bias = lower_model.b
|
||||
self.nF = cached.shape[1]
|
||||
self.nP = getattr(lower_model, 'nP', 1)
|
||||
self.nO = cached.shape[2]
|
||||
self.ops = lower_model.ops
|
||||
self.bias = lower_model.b
|
||||
self._is_synchronized = False
|
||||
self._cuda_stream = cuda_stream
|
||||
self._cached = cached
|
||||
|
@ -126,7 +129,8 @@ cdef class precompute_hiddens:
|
|||
state_vector += self.bias
|
||||
state_vector, bp_nonlinearity = self._nonlinearity(state_vector)
|
||||
|
||||
def backward(d_state_vector, sgd=None):
|
||||
def backward(d_state_vector_ids, sgd=None):
|
||||
d_state_vector, token_ids = d_state_vector_ids
|
||||
d_state_vector = bp_nonlinearity(d_state_vector, sgd)
|
||||
# This will usually be on GPU
|
||||
if not isinstance(d_state_vector, self.ops.xp.ndarray):
|
||||
|
@ -157,7 +161,8 @@ cdef void sum_state_features(float* output,
|
|||
const float* cached, const int* token_ids, int B, int F, int O) nogil:
|
||||
cdef int idx, b, f, i
|
||||
cdef const float* feature
|
||||
padding = cached - (F * O)
|
||||
padding = cached
|
||||
cached += F * O
|
||||
for b in range(B):
|
||||
for f in range(F):
|
||||
if token_ids[f] < 0:
|
||||
|
@ -657,7 +662,7 @@ cdef class Parser:
|
|||
cuda_stream.synchronize()
|
||||
xp = get_array_module(d_tokvecs)
|
||||
for ids, d_vector, bp_vector in backprops:
|
||||
d_state_features = bp_vector(d_vector, sgd=sgd)
|
||||
d_state_features = bp_vector((d_vector, ids), sgd=sgd)
|
||||
ids = ids.flatten()
|
||||
d_state_features = d_state_features.reshape(
|
||||
(ids.size, d_state_features.shape[2]))
|
||||
|
|
Loading…
Reference in New Issue
Block a user