mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 20:30:24 +03:00
mypy fixes
This commit is contained in:
parent
ca6aa239bc
commit
6d32ae01da
|
@ -1,4 +1,4 @@
|
||||||
from typing import List, Tuple, Any, Optional
|
from typing import List, Tuple, Any, Optional, cast
|
||||||
from thinc.api import Ops, Model, normal_init, chain, list2array, Linear
|
from thinc.api import Ops, Model, normal_init, chain, list2array, Linear
|
||||||
from thinc.api import uniform_init, glorot_uniform_init, zero_init
|
from thinc.api import uniform_init, glorot_uniform_init, zero_init
|
||||||
from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
|
from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
|
||||||
|
@ -399,10 +399,10 @@ def _lsuv_init(model: Model):
|
||||||
model.set_param("b", b)
|
model.set_param("b", b)
|
||||||
model.set_param("pad", pad)
|
model.set_param("pad", pad)
|
||||||
|
|
||||||
ids = ops.alloc((5000, nF), dtype="f")
|
ids = ops.alloc_f((5000, nF), dtype="f")
|
||||||
ids += ops.xp.random.uniform(0, 1000, ids.shape)
|
ids += ops.xp.random.uniform(0, 1000, ids.shape)
|
||||||
ids = ops.asarray(ids, dtype="i")
|
ids = ops.asarray(ids, dtype="i")
|
||||||
tokvecs = ops.alloc((5000, nI), dtype="f")
|
tokvecs = ops.alloc_f((5000, nI), dtype="f")
|
||||||
tokvecs += ops.xp.random.normal(loc=0.0, scale=1.0, size=tokvecs.size).reshape(
|
tokvecs += ops.xp.random.normal(loc=0.0, scale=1.0, size=tokvecs.size).reshape(
|
||||||
tokvecs.shape
|
tokvecs.shape
|
||||||
)
|
)
|
||||||
|
@ -421,8 +421,8 @@ def _lsuv_init(model: Model):
|
||||||
tol_var = 0.01
|
tol_var = 0.01
|
||||||
tol_mean = 0.01
|
tol_mean = 0.01
|
||||||
t_max = 10
|
t_max = 10
|
||||||
W = model.get_param("lower_W").copy()
|
W = cast(Floats4d, model.get_param("lower_W").copy())
|
||||||
b = model.get_param("lower_b").copy()
|
b = cast(Floats2d, model.get_param("lower_b").copy())
|
||||||
for t_i in range(t_max):
|
for t_i in range(t_max):
|
||||||
acts1 = predict(ids, tokvecs)
|
acts1 = predict(ids, tokvecs)
|
||||||
var = model.ops.xp.var(acts1)
|
var = model.ops.xp.var(acts1)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user