From 6d32ae01daeba20b3d3f1abab8ce8906aac34e3a Mon Sep 17 00:00:00 2001 From: svlandeg Date: Thu, 20 Jan 2022 17:50:37 +0100 Subject: [PATCH] mypy fixes --- spacy/ml/tb_framework.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/spacy/ml/tb_framework.py b/spacy/ml/tb_framework.py index 753c99cb9..9aac5b801 100644 --- a/spacy/ml/tb_framework.py +++ b/spacy/ml/tb_framework.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Any, Optional +from typing import List, Tuple, Any, Optional, cast from thinc.api import Ops, Model, normal_init, chain, list2array, Linear from thinc.api import uniform_init, glorot_uniform_init, zero_init from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d @@ -399,10 +399,10 @@ def _lsuv_init(model: Model): model.set_param("b", b) model.set_param("pad", pad) - ids = ops.alloc((5000, nF), dtype="f") + ids = ops.alloc_f((5000, nF), dtype="f") ids += ops.xp.random.uniform(0, 1000, ids.shape) ids = ops.asarray(ids, dtype="i") - tokvecs = ops.alloc((5000, nI), dtype="f") + tokvecs = ops.alloc_f((5000, nI), dtype="f") tokvecs += ops.xp.random.normal(loc=0.0, scale=1.0, size=tokvecs.size).reshape( tokvecs.shape ) @@ -421,8 +421,8 @@ def _lsuv_init(model: Model): tol_var = 0.01 tol_mean = 0.01 t_max = 10 - W = model.get_param("lower_W").copy() - b = model.get_param("lower_b").copy() + W = cast(Floats4d, model.get_param("lower_W").copy()) + b = cast(Floats2d, model.get_param("lower_b").copy()) for t_i in range(t_max): acts1 = predict(ids, tokvecs) var = model.ops.xp.var(acts1)