mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 04:08:09 +03:00
Fix tensor gradient in parser
This commit is contained in:
parent
e420e0366c
commit
a8e4064dd8
|
@ -633,10 +633,9 @@ cdef class Parser:
|
|||
xp = get_array_module(d_tokvecs)
|
||||
for ids, d_vector, bp_vector in backprops:
|
||||
d_state_features = bp_vector(d_vector, sgd=sgd)
|
||||
mask = ids >= 0
|
||||
indices = xp.nonzero(mask)
|
||||
self.model[0].ops.scatter_add(d_tokvecs, ids[indices],
|
||||
d_state_features[indices])
|
||||
mask = (ids >= 0).reshape((ids.shape[0], ids.shape[1], 1))
|
||||
self.model[0].ops.scatter_add(d_tokvecs, ids,
|
||||
d_state_features * mask)
|
||||
|
||||
@property
|
||||
def move_names(self):
|
||||
|
|
Loading…
Reference in New Issue
Block a user