mypy fixes

This commit is contained in:
svlandeg 2022-01-20 17:50:37 +01:00
parent ca6aa239bc
commit 6d32ae01da

View File

@ -1,4 +1,4 @@
from typing import List, Tuple, Any, Optional from typing import List, Tuple, Any, Optional, cast
from thinc.api import Ops, Model, normal_init, chain, list2array, Linear from thinc.api import Ops, Model, normal_init, chain, list2array, Linear
from thinc.api import uniform_init, glorot_uniform_init, zero_init from thinc.api import uniform_init, glorot_uniform_init, zero_init
from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
@ -399,10 +399,10 @@ def _lsuv_init(model: Model):
model.set_param("b", b) model.set_param("b", b)
model.set_param("pad", pad) model.set_param("pad", pad)
ids = ops.alloc((5000, nF), dtype="f") ids = ops.alloc_f((5000, nF), dtype="f")
ids += ops.xp.random.uniform(0, 1000, ids.shape) ids += ops.xp.random.uniform(0, 1000, ids.shape)
ids = ops.asarray(ids, dtype="i") ids = ops.asarray(ids, dtype="i")
tokvecs = ops.alloc((5000, nI), dtype="f") tokvecs = ops.alloc_f((5000, nI), dtype="f")
tokvecs += ops.xp.random.normal(loc=0.0, scale=1.0, size=tokvecs.size).reshape( tokvecs += ops.xp.random.normal(loc=0.0, scale=1.0, size=tokvecs.size).reshape(
tokvecs.shape tokvecs.shape
) )
@ -421,8 +421,8 @@ def _lsuv_init(model: Model):
tol_var = 0.01 tol_var = 0.01
tol_mean = 0.01 tol_mean = 0.01
t_max = 10 t_max = 10
W = model.get_param("lower_W").copy() W = cast(Floats4d, model.get_param("lower_W").copy())
b = model.get_param("lower_b").copy() b = cast(Floats2d, model.get_param("lower_b").copy())
for t_i in range(t_max): for t_i in range(t_max):
acts1 = predict(ids, tokvecs) acts1 = predict(ids, tokvecs)
var = model.ops.xp.var(acts1) var = model.ops.xp.var(acts1)