Fix type of mask

The call here was creating a float64 array, which was turning many
downstream scores into float64s. Later on these values were assigned to
a float32 array in backprop, and numerical underflow caused things to go
to zero.

That's almost certainly not the only reason things go to zero, but it is
incorrect.
This commit is contained in:
Paul O'Leary McCann 2021-06-17 17:56:00 +09:00
parent 8452d117ef
commit cb2364cf83

View File

@ -373,7 +373,7 @@ def ant_scorer_forward(
warnings.filterwarnings('ignore', category=RuntimeWarning)
mask = xp.log(
(xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1
).astype(float)
).astype('f')
scores = pw_prod + pw_sum + mask