mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 12:18:04 +03:00
Fix tensor gradient in parser
This commit is contained in:
parent
e420e0366c
commit
a8e4064dd8
|
@ -633,10 +633,9 @@ cdef class Parser:
|
||||||
xp = get_array_module(d_tokvecs)
|
xp = get_array_module(d_tokvecs)
|
||||||
for ids, d_vector, bp_vector in backprops:
|
for ids, d_vector, bp_vector in backprops:
|
||||||
d_state_features = bp_vector(d_vector, sgd=sgd)
|
d_state_features = bp_vector(d_vector, sgd=sgd)
|
||||||
mask = ids >= 0
|
mask = (ids >= 0).reshape((ids.shape[0], ids.shape[1], 1))
|
||||||
indices = xp.nonzero(mask)
|
self.model[0].ops.scatter_add(d_tokvecs, ids,
|
||||||
self.model[0].ops.scatter_add(d_tokvecs, ids[indices],
|
d_state_features * mask)
|
||||||
d_state_features[indices])
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def move_names(self):
|
def move_names(self):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user