mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 20:30:24 +03:00
mypy fixes
This commit is contained in:
parent
ca6aa239bc
commit
6d32ae01da
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Tuple, Any, Optional
|
||||
from typing import List, Tuple, Any, Optional, cast
|
||||
from thinc.api import Ops, Model, normal_init, chain, list2array, Linear
|
||||
from thinc.api import uniform_init, glorot_uniform_init, zero_init
|
||||
from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
|
||||
|
@ -399,10 +399,10 @@ def _lsuv_init(model: Model):
|
|||
model.set_param("b", b)
|
||||
model.set_param("pad", pad)
|
||||
|
||||
ids = ops.alloc((5000, nF), dtype="f")
|
||||
ids = ops.alloc_f((5000, nF), dtype="f")
|
||||
ids += ops.xp.random.uniform(0, 1000, ids.shape)
|
||||
ids = ops.asarray(ids, dtype="i")
|
||||
tokvecs = ops.alloc((5000, nI), dtype="f")
|
||||
tokvecs = ops.alloc_f((5000, nI), dtype="f")
|
||||
tokvecs += ops.xp.random.normal(loc=0.0, scale=1.0, size=tokvecs.size).reshape(
|
||||
tokvecs.shape
|
||||
)
|
||||
|
@ -421,8 +421,8 @@ def _lsuv_init(model: Model):
|
|||
tol_var = 0.01
|
||||
tol_mean = 0.01
|
||||
t_max = 10
|
||||
W = model.get_param("lower_W").copy()
|
||||
b = model.get_param("lower_b").copy()
|
||||
W = cast(Floats4d, model.get_param("lower_W").copy())
|
||||
b = cast(Floats2d, model.get_param("lower_b").copy())
|
||||
for t_i in range(t_max):
|
||||
acts1 = predict(ids, tokvecs)
|
||||
var = model.ops.xp.var(acts1)
|
||||
|
|
Loading…
Reference in New Issue
Block a user